"].replace("The image shows ", "")
+ if concept_sentence:
+ caption_text = f"{caption_text} [trigger]"
+ captions[i] = caption_text
+
+ yield captions
+ model.to("cpu")
+ del model
+ del processor
+
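+# Deep-merge dict u into dict d: nested dicts are merged recursively, all other values in u overwrite d.
+# e.g. recursive_update({"a": {"b": 1}}, {"a": {"c": 2}}) returns {"a": {"b": 1, "c": 2}}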
+def recursive_update(d, u):
+ for k, v in u.items():
+ if isinstance(v, dict) and v:
+ d[k] = recursive_update(d.get(k, {}), v)
+ else:
+ d[k] = v
+ return d
+
+def start_training(
+ lora_name,
+ concept_sentence,
+ steps,
+ lr,
+ rank,
+ model_to_train,
+ low_vram,
+ dataset_folder,
+ sample_1,
+ sample_2,
+ sample_3,
+ use_more_advanced_options,
+ more_advanced_options,
+):
+ push_to_hub = True
+ if not lora_name:
+ raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
+ try:
+ if whoami()["auth"]["accessToken"]["role"] == "write" or "repo.write" in whoami()["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
+ gr.Info(f"Starting training locally {whoami()['name']}. Your LoRA will be available locally and in Hugging Face after it finishes.")
+ else:
+ push_to_hub = False
+ gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
+ except Exception:
+ push_to_hub = False
+ gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
+
+ print("Started training")
+ slugged_lora_name = slugify(lora_name)
+
+ # Load the default config
+ with open("config/examples/train_lora_flux_24gb.yaml", "r") as f:
+ config = yaml.safe_load(f)
+
+ # Update the config with user inputs
+ config["config"]["name"] = slugged_lora_name
+ config["config"]["process"][0]["model"]["low_vram"] = low_vram
+ config["config"]["process"][0]["train"]["skip_first_sample"] = True
+ config["config"]["process"][0]["train"]["steps"] = int(steps)
+ config["config"]["process"][0]["train"]["lr"] = float(lr)
+ config["config"]["process"][0]["network"]["linear"] = int(rank)
+ config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
+ config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
+ config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
+ if push_to_hub:
+ try:
+ username = whoami()["name"]
+ except Exception:
+ raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
+ config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
+ config["config"]["process"][0]["save"]["hf_private"] = True
+ if concept_sentence:
+ config["config"]["process"][0]["trigger_word"] = concept_sentence
+
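+ # Enable periodic sampling only when at least one test prompt was provided; otherwise disable sampling entirely.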
+ if sample_1 or sample_2 or sample_3:
+ config["config"]["process"][0]["train"]["disable_sampling"] = False
+ config["config"]["process"][0]["sample"]["sample_every"] = steps
+ config["config"]["process"][0]["sample"]["sample_steps"] = 28
+ config["config"]["process"][0]["sample"]["prompts"] = []
+ if sample_1:
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
+ if sample_2:
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
+ if sample_3:
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
+ else:
+ config["config"]["process"][0]["train"]["disable_sampling"] = True
+ if(model_to_train == "schnell"):
+ config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
+ config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
+ config["config"]["process"][0]["sample"]["sample_steps"] = 4
+ if use_more_advanced_options:
+ more_advanced_options_dict = yaml.safe_load(more_advanced_options)
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
+ print(config)
+
+ # Save the updated config
+ # generate a random name for the config
+ random_config_name = str(uuid.uuid4())
+ os.makedirs("tmp", exist_ok=True)
+ config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+
+ # run the job locally
+ job = get_job(config_path)
+ job.run()
+ job.cleanup()
+
+ return f"Training completed successfully. Model saved as {slugged_lora_name}"
+
+config_yaml = '''
+device: cuda:0
+model:
+ is_flux: true
+ quantize: true
+network:
+ linear: 16 # overrides the 'rank' parameter from the UI
+ linear_alpha: 16 # alpha can differ from the rank if you'd like
+ type: lora
+sample:
+ guidance_scale: 3.5
+ height: 1024
+ neg: '' #doesn't work for FLUX
+ sample_every: 1000
+ sample_steps: 28
+ sampler: flowmatch
+ seed: 42
+ walk_seed: true
+ width: 1024
+save:
+ dtype: float16
+ hf_private: true
+ max_step_saves_to_keep: 4
+ push_to_hub: true
+ save_every: 10000
+train:
+ batch_size: 1
+ dtype: bf16
+ ema_config:
+ ema_decay: 0.99
+ use_ema: true
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: true
+ noise_scheduler: flowmatch
+ optimizer: adamw8bit #options: prodigy, dadaptation, adamw, adamw8bit, lion, lion8bit
+ train_text_encoder: false #probably doesn't work for flux
+ train_unet: true
+'''
+
+theme = gr.themes.Monochrome(
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
+)
+css = """
+h1{font-size: 2em}
+h3{margin-top: 0}
+#component-1{text-align:center}
+.main_ui_logged_out{opacity: 0.3; pointer-events: none}
+.tabitem{border: 0px}
+.group_padding{padding: .55em}
+"""
+with gr.Blocks(theme=theme, css=css) as demo:
+ gr.Markdown(
+ """# LoRA Ease for FLUX 🧞♂️
+### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)"""
+ )
+ with gr.Column() as main_ui:
+ with gr.Row():
+ lora_name = gr.Textbox(
+ label="The name of your LoRA",
+ info="This has to be a unique name",
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
+ )
+ concept_sentence = gr.Textbox(
+ label="Trigger word/sentence",
+ info="Trigger word or sentence to be used",
+ placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
+ interactive=True,
+ )
+ with gr.Group(visible=True) as image_upload:
+ with gr.Row():
+ images = gr.File(
+ file_types=["image", ".txt"],
+ label="Upload your images",
+ file_count="multiple",
+ interactive=True,
+ visible=True,
+ scale=1,
+ )
+ with gr.Column(scale=3, visible=False) as captioning_area:
+ with gr.Column():
+ gr.Markdown(
+ """# Custom captioning
+You can optionally add a custom caption for each image (or use an AI model for this). [trigger] will represent your concept sentence/trigger word.
+""", elem_classes="group_padding")
+ do_captioning = gr.Button("Add AI captions with Florence-2")
+ output_components = [captioning_area]
+ caption_list = []
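+ # note: assigning through locals() works here only because this block runs at module scope,
+ # where locals() is the module's globals(); inside a function these assignments would be ignored.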
+ for i in range(1, MAX_IMAGES + 1):
+ locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
+ with locals()[f"captioning_row_{i}"]:
+ locals()[f"image_{i}"] = gr.Image(
+ type="filepath",
+ width=111,
+ height=111,
+ min_width=111,
+ interactive=False,
+ scale=2,
+ show_label=False,
+ show_share_button=False,
+ show_download_button=False,
+ )
+ locals()[f"caption_{i}"] = gr.Textbox(
+ label=f"Caption {i}", scale=15, interactive=True
+ )
+
+ output_components.append(locals()[f"captioning_row_{i}"])
+ output_components.append(locals()[f"image_{i}"])
+ output_components.append(locals()[f"caption_{i}"])
+ caption_list.append(locals()[f"caption_{i}"])
+
+ with gr.Accordion("Advanced options", open=False):
+ steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
+ lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
+ rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
+ model_to_train = gr.Radio(["dev", "schnell"], value="dev", label="Model to train")
+ low_vram = gr.Checkbox(label="Low VRAM", value=True)
+ with gr.Accordion("Even more advanced options", open=False):
+ use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False)
+ more_advanced_options = gr.Code(config_yaml, language="yaml")
+
+ with gr.Accordion("Sample prompts (optional)", visible=False) as sample:
+ gr.Markdown(
+ "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)"
+ )
+ sample_1 = gr.Textbox(label="Test prompt 1")
+ sample_2 = gr.Textbox(label="Test prompt 2")
+ sample_3 = gr.Textbox(label="Test prompt 3")
+
+ output_components.append(sample)
+ output_components.append(sample_1)
+ output_components.append(sample_2)
+ output_components.append(sample_3)
+ start = gr.Button("Start training", visible=False)
+ output_components.append(start)
+ progress_area = gr.Markdown("")
+
+ dataset_folder = gr.State()
+
+ images.upload(
+ load_captioning,
+ inputs=[images, concept_sentence],
+ outputs=output_components
+ )
+
+ images.delete(
+ load_captioning,
+ inputs=[images, concept_sentence],
+ outputs=output_components
+ )
+
+ images.clear(
+ hide_captioning,
+ outputs=[captioning_area, sample, start]
+ )
+
+ start.click(fn=create_dataset, inputs=[images] + caption_list, outputs=dataset_folder).then(
+ fn=start_training,
+ inputs=[
+ lora_name,
+ concept_sentence,
+ steps,
+ lr,
+ rank,
+ model_to_train,
+ low_vram,
+ dataset_folder,
+ sample_1,
+ sample_2,
+ sample_3,
+ use_more_advanced_options,
+ more_advanced_options
+ ],
+ outputs=progress_area,
+ )
+
+ do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
+
+if __name__ == "__main__":
+ demo.launch(share=True, show_error=True)
\ No newline at end of file
diff --git a/hf_ui.py b/hf_ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..df41f0236932484ea32c6f8d7a4a26f324481303
--- /dev/null
+++ b/hf_ui.py
@@ -0,0 +1,417 @@
+import os
+from huggingface_hub import whoami
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+import sys
+import subprocess
+import gradio as gr
+import uuid
+import os
+import shutil
+import json
+import yaml
+from slugify import slugify
+from run_modal_from_hf import main
+
+# Add the current working directory to the Python path
+sys.path.insert(0, os.getcwd())
+sys.path.insert(0, "ai-toolkit")
+
+MAX_IMAGES = 150
+
+# Import app and main directly from run_modal_from_hf
+
+def create_dataset(*inputs):
+ print("Creating dataset")
+ files = inputs[0]
+ destination_folder = str(f"datasets/{uuid.uuid4()}")
+ os.makedirs(destination_folder, exist_ok=True)
+
+ if files is not None:
+ # Handle both single and multiple files
+ if not isinstance(files, list):
+ files = [files] # convert to a list if it is not one
+
+ # Sort files by type
+ image_files = []
+ caption_files = []
+ zip_files = []
+
+ for file in files:
+ ext = os.path.splitext(file.name)[1].lower()
+ if ext in ['.jpg', '.jpeg', '.png']:
+ image_files.append(file)
+ elif ext == '.txt':
+ caption_files.append(file)
+ elif ext == '.zip':
+ zip_files.append(file)
+ else:
+ raise ValueError(f"Unsupported file type: {ext}")
+
+ # Validate file counts
+ if len(zip_files) > 1:
+ raise ValueError("Please upload only one zip file")
+
+ if zip_files and (image_files or caption_files):
+ raise ValueError("Please upload either a zip file OR individual files, not both")
+
+ # Copy loose image/caption files into the destination folder
+ for file in image_files + caption_files:
+ shutil.copy2(file.name, destination_folder)
+
+ # If a zip file was uploaded, copy only the zip file (it will be unpacked on the Modal side)
+ if zip_files:
+ shutil.copy2(zip_files[0].name, destination_folder)
+
+ # Validate pairing when loose files were uploaded
+ if image_files or caption_files:
+ validate_image_caption_pairs(destination_folder)
+
+ return destination_folder
+
+def validate_image_caption_pairs(folder_path):
+ """Validate images và captions nếu được upload riêng lẻ"""
+ images = []
+ captions = []
+
+ for file in os.listdir(folder_path):
+ name, ext = os.path.splitext(file)
+ ext = ext.lower()
+
+ if ext in ['.jpg', '.jpeg', '.png']:
+ images.append(name)
+ elif ext == '.txt':
+ captions.append(name)
+
+ # If any captions were provided, every image must have a matching caption
+ if captions:
+ missing_captions = []
+ for img in images:
+ if img not in captions:
+ missing_captions.append(img)
+
+ if missing_captions:
+ raise ValueError(f"Missing captions for images: {', '.join(missing_captions)}")
+
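+# Deep-merge dict u into dict d (nested dicts merged recursively; other values in u overwrite d).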
+def recursive_update(d, u):
+ for k, v in u.items():
+ if isinstance(v, dict) and v:
+ d[k] = recursive_update(d.get(k, {}), v)
+ else:
+ d[k] = v
+ return d
+
+def start_training(
+ lora_name,
+ concept_sentence,
+ steps,
+ lr,
+ rank,
+ model_to_train,
+ low_vram,
+ dataset_folder,
+ sample_1,
+ sample_2,
+ sample_3,
+ use_more_advanced_options,
+ more_advanced_options,
+ push_to_hub,
+ use_wandb,
+):
+ print("Starting training from gradio app")
+
+ # build config
+ config = {
+ "job": "extension",
+ "config": {
+ "name": lora_name,
+ "process": [
+ {
+ "type": "sd_trainer",
+ }
+ ]
+ },
+ }
+ # build main config
+ config['config']['process'][0]['training_folder'] = "/root/ai-toolkit/modal_output"
+ config['config']['process'][0]['device'] = "cuda:0"
+ config['config']['process'][0]['network'] = {
+ "type": "lora",
+ "linear": int(rank),
+ "linear_alpha": int(rank)
+ }
+ config['config']['process'][0]['save'] = {
+ "dtype": "float16",
+ "save_every": int(steps),
+ "max_step_saves_to_keep": 4,
+ "push_to_hub": push_to_hub,
+ "hf_repo_id": f"test/{slugify(lora_name)}",
+ "hf_private": True
+ }
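+ # note: hf_repo_id above uses "test" as what looks like a placeholder namespace; swap in your
+ # Hugging Face username if you actually enable push_to_hub.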
+
+ config['config']['process'][0]['datasets'] = [{
+ "folder_path": "/root/ai-toolkit/" + dataset_folder, # MUST match modal directory
+ "caption_ext": "txt",
+ "caption_dropout_rate": 0.05,
+ "shuffle_tokens": False,
+ "cache_latents_to_disk": True,
+ "resolution": [512, 768, 1024]
+ }]
+
+ config['config']['process'][0]['train'] = {
+ "batch_size": 1,
+ "steps": int(steps),
+ "gradient_accumulation_steps": 1,
+ "train_unet": True,
+ "train_text_encoder": False,
+ "gradient_checkpointing": True,
+ "noise_scheduler": "flowmatch",
+ "optimizer": "adamw8bit",
+ "lr": float(lr),
+ "dtype": "bf16",
+ "ema_config": {
+ "use_ema": True,
+ "ema_decay": 0.99
+ }
+ }
+
+ config['config']['process'][0]['model'] = {
+ "name_or_path": "black-forest-labs/FLUX.1-dev",
+ "is_flux": True,
+ "quantize": True,
+ "low_vram": low_vram
+ }
+ config['config']['process'][0]['sample'] = {
+ "sampler": "flowmatch",
+ "sample_every": int(steps),
+ "width": 1024,
+ "height": 1024,
+ "prompts": [
+ f"woman with red hair, playing chess at the park, bomb going off in the background {concept_sentence}",
+ f"a woman holding a coffee cup, in a beanie, sitting at a cafe {concept_sentence}",
+ f"a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini {concept_sentence}",
+ f"a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background {concept_sentence}",
+ f"a bear building a log cabin in the snow covered mountains {concept_sentence}",
+ f"woman playing the guitar, on stage, singing a song, laser lights, punk rocker {concept_sentence}",
+ f"hipster man with a beard, building a chair, in a wood shop {concept_sentence}",
+ f"photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop {concept_sentence}",
+ f"a man holding a sign that says, 'this is a sign' {concept_sentence}",
+ f"a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle {concept_sentence}"
+ ],
+ "neg": "",
+ "seed": 42,
+ "walk_seed": True,
+ "guidance_scale": 4,
+ "sample_steps": 20
+ }
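+ # Any user-supplied test prompts replace the default prompt list built above.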
+ if sample_1 or sample_2 or sample_3:
+ config['config']['process'][0]["sample"]['prompts'] = []
+ if sample_1:
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_1)
+ if sample_2:
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
+ if sample_3:
+ config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
+
+ if concept_sentence:
+ config['config']['process'][0]['trigger_word'] = concept_sentence
+
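+ # schnell swaps in the FLUX.1-schnell base model plus Ostris' training adapter and uses fewer
+ # sample steps with guidance scale 1.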
+ if(model_to_train == "schnell"):
+ config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
+ config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
+ config["config"]["process"][0]["sample"]["sample_steps"] = 4
+ config["config"]["process"][0]["sample"]["guidance_scale"] = 1
+
+ if use_more_advanced_options:
+ more_advanced_options_dict = yaml.safe_load(more_advanced_options)
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
+
+ # add wandb if needed
+ config['config']['process'][0]['logging'] = {
+ "log_every": 10,
+ "use_wandb": use_wandb,
+ "verbose": False
+ }
+
+ # pass to modal function
+ config_file_list_str = json.dumps(config)
+
+ try:
+ main.remote(
+ config_file_list_str=config_file_list_str,
+ recover=True,
+ name=lora_name
+ )
+ return "Training started in Modal. Check your Modal dashboard for logs and status"
+ except Exception as e:
+ return f"Error starting training: {str(e)}"
+
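+# Parse a "modal token set --token-id ... --token-secret ..." command string and run it through the
+# Modal CLI via subprocess so the token is stored for later runs.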
+def setup_modal_token(token_command):
+ try:
+ # Split the command string into its parts
+ parts = token_command.strip().split()
+ if len(parts) == 7 and parts[0] == "modal" and parts[1] == "token" and parts[2] == "set":
+ token_id = parts[4]
+ token_secret = parts[6]
+
+ # Execute the command
+ result = subprocess.run(
+ ["modal", "token", "set", "--token-id", token_id, "--token-secret", token_secret],
+ capture_output=True,
+ text=True
+ )
+
+ if result.returncode == 0:
+ return "Modal token đã được cấu hình thành công!"
+ else:
+ return f"Lỗi khi cấu hình token: {result.stderr}"
+ except Exception as e:
+ return f"Lỗi: {str(e)}"
+
+config_yaml = '''
+device: cuda:0
+model:
+ is_flux: true
+ quantize: true
+network:
+ linear: 16 # overrides the 'rank' parameter from the UI
+ linear_alpha: 16 # alpha can differ from the rank if you'd like
+ type: lora
+sample:
+ guidance_scale: 3.5
+ height: 1024
+ neg: '' #doesn't work for FLUX
+ sample_every: 1000
+ sample_steps: 28
+ sampler: flowmatch
+ seed: 42
+ walk_seed: true
+ width: 1024
+save:
+ dtype: float16
+ hf_private: true
+ max_step_saves_to_keep: 4
+ push_to_hub: true
+ save_every: 10000
+train:
+ batch_size: 1
+ dtype: bf16
+ ema_config:
+ ema_decay: 0.99
+ use_ema: true
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: true
+ noise_scheduler: flowmatch
+ optimizer: adamw8bit
+ train_text_encoder: false #probably doesn't work for flux
+ train_unet: true
+'''
+
+theme = gr.themes.Monochrome(
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
+)
+css = """
+h1{font-size: 2em}
+h3{margin-top: 0}
+#component-1{text-align:center}
+.main_ui_logged_out{opacity: 0.3; pointer-events: none}
+.tabitem{border: 0px}
+.group_padding{padding: .55em}
+"""
+with gr.Blocks(theme=theme, css=css) as demo:
+ gr.Markdown(
+ """# LoRA Ease for FLUX 🧞♂️
+### Train a high quality FLUX LoRA in a breeze ༄ using [Ostris' AI Toolkit](https://github.com/ostris/ai-toolkit)"""
+ )
+ with gr.Column() as main_ui:
+ with gr.Row():
+ lora_name = gr.Textbox(
+ label="The name of your LoRA",
+ info="This has to be a unique name",
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
+ )
+ concept_sentence = gr.Textbox(
+ label="Trigger word/sentence",
+ info="Trigger word or sentence to be used",
+ placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
+ interactive=True,
+ )
+ with gr.Group(visible=True) as image_upload:
+ with gr.Row():
+ images = gr.File(
+ file_types=["image", ".txt", ".zip"],
+ label="Upload your dataset as zip or multiple images",
+ file_count="multiple",
+ interactive=True,
+ visible=True,
+ scale=1,
+ )
+
+ with gr.Accordion("Advanced options", open=False):
+ steps = gr.Number(label="Steps", value=1000, minimum=1, maximum=10000, step=1)
+ lr = gr.Number(label="Learning Rate", value=4e-4, minimum=1e-6, maximum=1e-3, step=1e-6)
+ rank = gr.Number(label="LoRA Rank", value=16, minimum=4, maximum=128, step=4)
+ model_to_train = gr.Radio(["dev", "schnell"], value="dev", label="Model to train")
+ low_vram = gr.Checkbox(label="Low VRAM", value=True)
+ with gr.Accordion("Even more advanced options", open=False):
+ use_more_advanced_options = gr.Checkbox(label="Use more advanced options", value=False)
+ more_advanced_options = gr.Code(config_yaml, language="yaml")
+
+ with gr.Accordion("Sample prompts (optional)", visible=False) as sample:
+ gr.Markdown(
+ "Include sample prompts to test out your trained model. Don't forget to include your trigger word/sentence (optional)"
+ )
+ sample_1 = gr.Textbox(label="Test prompt 1")
+ sample_2 = gr.Textbox(label="Test prompt 2")
+ sample_3 = gr.Textbox(label="Test prompt 3")
+
+ output_components = [sample, sample_1, sample_2, sample_3]
+
+ with gr.Row():
+ push_to_hub = gr.Checkbox(label="Push to Hub", value=True)
+ use_wandb = gr.Checkbox(label="Use WandB", value=False)
+ start = gr.Button("Start training")
+ output_components.append(start)
+ progress_area = gr.Markdown("")
+ output_components.append(progress_area)
+
+ dataset_folder = gr.State()
+
+ with gr.Accordion("Modal Configuration", open=False):
+ modal_token_input = gr.Textbox(
+ label="Nhập lệnh Modal token",
+ placeholder="modal token set --token-id YOUR_TOKEN_ID --token-secret YOUR_TOKEN_SECRET"
+ )
+ modal_setup_btn = gr.Button("Setup Modal Token")
+ modal_status = gr.Markdown("")
+
+ modal_setup_btn.click(
+ fn=setup_modal_token,
+ inputs=[modal_token_input],
+ outputs=[modal_status]
+ )
+
+ start.click(fn=create_dataset, inputs=[images], outputs=dataset_folder).then(
+ fn=start_training,
+ inputs=[
+ lora_name,
+ concept_sentence,
+ steps,
+ lr,
+ rank,
+ model_to_train,
+ low_vram,
+ dataset_folder,
+ sample_1,
+ sample_2,
+ sample_3,
+ use_more_advanced_options,
+ more_advanced_options,
+ push_to_hub,
+ use_wandb
+ ],
+ outputs=progress_area,
+ )
+
+if __name__ == "__main__":
+ demo.launch(share=True, show_error=True)
diff --git a/info.py b/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f2f0a97403deb778f0549c5fed2f9972ac75209
--- /dev/null
+++ b/info.py
@@ -0,0 +1,8 @@
+from collections import OrderedDict
+
+v = OrderedDict()
+v["name"] = "ai-toolkit"
+v["repo"] = "https://github.com/ostris/ai-toolkit"
+v["version"] = "0.1.0"
+
+software_meta = v
diff --git a/jobs/BaseJob.py b/jobs/BaseJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..8efd0097c6898cd8a6087fe9299f7e191f5a893a
--- /dev/null
+++ b/jobs/BaseJob.py
@@ -0,0 +1,72 @@
+import importlib
+from collections import OrderedDict
+from typing import List
+
+from jobs.process import BaseProcess
+
+
+class BaseJob:
+
+ def __init__(self, config: OrderedDict):
+ if not config:
+ raise ValueError('config is required')
+ self.process: List[BaseProcess]
+
+ self.config = config['config']
+ self.raw_config = config
+ self.job = config['job']
+ self.torch_profiler = self.get_conf('torch_profiler', False)
+ self.name = self.get_conf('name', required=True)
+ if 'meta' in config:
+ self.meta = config['meta']
+ else:
+ self.meta = OrderedDict()
+
+ def get_conf(self, key, default=None, required=False):
+ if key in self.config:
+ return self.config[key]
+ elif required:
+ raise ValueError(f'config file error. Missing "config.{key}" key')
+ else:
+ return default
+
+ def run(self):
+ print("")
+ print(f"#############################################")
+ print(f"# Running job: {self.name}")
+ print(f"#############################################")
+ print("")
+ # implement in child class
+ # be sure to call super().run() first
+ pass
+
+ def load_processes(self, process_dict: dict):
+ # only call if you have processes in this job type
+ if 'process' not in self.config:
+ raise ValueError('config file is invalid. Missing "config.process" key')
+ if len(self.config['process']) == 0:
+ raise ValueError('config file is invalid. "config.process" must be a list of processes')
+
+ module = importlib.import_module('jobs.process')
+
+ # add the processes
+ self.process = []
+ for i, process in enumerate(self.config['process']):
+ if 'type' not in process:
+ raise ValueError(f'config file is invalid. Missing "config.process[{i}].type" key')
+
+ # check if dict key is process type
+ if process['type'] in process_dict:
+ if isinstance(process_dict[process['type']], str):
+ ProcessClass = getattr(module, process_dict[process['type']])
+ else:
+ # it is the class
+ ProcessClass = process_dict[process['type']]
+ self.process.append(ProcessClass(i, self, process))
+ else:
+ raise ValueError(f'config file is invalid. Unknown process type: {process["type"]}')
+
+ def cleanup(self):
+ # if you implement this in a child class,
+ # be sure to call super().cleanup() LAST
+ del self
diff --git a/jobs/ExtensionJob.py b/jobs/ExtensionJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..def4f8530a8a92c65369cd63a3e69c16bf0bb7de
--- /dev/null
+++ b/jobs/ExtensionJob.py
@@ -0,0 +1,22 @@
+import os
+from collections import OrderedDict
+from jobs import BaseJob
+from toolkit.extension import get_all_extensions_process_dict
+from toolkit.paths import CONFIG_ROOT
+
+class ExtensionJob(BaseJob):
+
+ def __init__(self, config: OrderedDict):
+ super().__init__(config)
+ self.device = self.get_conf('device', 'cpu')
+ self.process_dict = get_all_extensions_process_dict()
+ self.load_processes(self.process_dict)
+
+ def run(self):
+ super().run()
+
+ print("")
+ print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+ for process in self.process:
+ process.run()
diff --git a/jobs/ExtractJob.py b/jobs/ExtractJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..d710d4128db5304569357ee05d2fb31fa15c6e39
--- /dev/null
+++ b/jobs/ExtractJob.py
@@ -0,0 +1,58 @@
+from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
+from collections import OrderedDict
+from jobs import BaseJob
+from toolkit.train_tools import get_torch_dtype
+
+process_dict = {
+ 'locon': 'ExtractLoconProcess',
+ 'lora': 'ExtractLoraProcess',
+}
+
+
+class ExtractJob(BaseJob):
+
+ def __init__(self, config: OrderedDict):
+ super().__init__(config)
+ self.base_model_path = self.get_conf('base_model', required=True)
+ self.model_base = None
+ self.model_base_text_encoder = None
+ self.model_base_vae = None
+ self.model_base_unet = None
+ self.extract_model_path = self.get_conf('extract_model', required=True)
+ self.model_extract = None
+ self.model_extract_text_encoder = None
+ self.model_extract_vae = None
+ self.model_extract_unet = None
+ self.extract_unet = self.get_conf('extract_unet', True)
+ self.extract_text_encoder = self.get_conf('extract_text_encoder', True)
+ self.dtype = self.get_conf('dtype', 'fp16')
+ self.torch_dtype = get_torch_dtype(self.dtype)
+ self.output_folder = self.get_conf('output_folder', required=True)
+ self.is_v2 = self.get_conf('is_v2', False)
+ self.device = self.get_conf('device', 'cpu')
+
+ # loads the processes from the config
+ self.load_processes(process_dict)
+
+ def run(self):
+ super().run()
+ # load models
+ print(f"Loading models for extraction")
+ print(f" - Loading base model: {self.base_model_path}")
+ # (text_model, vae, unet)
+ self.model_base = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.base_model_path)
+ self.model_base_text_encoder = self.model_base[0]
+ self.model_base_vae = self.model_base[1]
+ self.model_base_unet = self.model_base[2]
+
+ print(f" - Loading extract model: {self.extract_model_path}")
+ self.model_extract = load_models_from_stable_diffusion_checkpoint(self.is_v2, self.extract_model_path)
+ self.model_extract_text_encoder = self.model_extract[0]
+ self.model_extract_vae = self.model_extract[1]
+ self.model_extract_unet = self.model_extract[2]
+
+ print("")
+ print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+ for process in self.process:
+ process.run()
diff --git a/jobs/GenerateJob.py b/jobs/GenerateJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab61701a3bc5e3a0f21e67c5aeee7c67b5e8f7c2
--- /dev/null
+++ b/jobs/GenerateJob.py
@@ -0,0 +1,31 @@
+from jobs import BaseJob
+from collections import OrderedDict
+from typing import List
+from jobs.process import GenerateProcess
+from toolkit.paths import REPOS_ROOT
+
+import sys
+
+sys.path.append(REPOS_ROOT)
+
+process_dict = {
+ 'to_folder': 'GenerateProcess',
+}
+
+
+class GenerateJob(BaseJob):
+
+ def __init__(self, config: OrderedDict):
+ super().__init__(config)
+ self.device = self.get_conf('device', 'cpu')
+
+ # loads the processes from the config
+ self.load_processes(process_dict)
+
+ def run(self):
+ super().run()
+ print("")
+ print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+ for process in self.process:
+ process.run()
diff --git a/jobs/MergeJob.py b/jobs/MergeJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e3b87b9ff589438d06c56019446f06efb76cda
--- /dev/null
+++ b/jobs/MergeJob.py
@@ -0,0 +1,29 @@
+from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
+from collections import OrderedDict
+from jobs import BaseJob
+from toolkit.train_tools import get_torch_dtype
+
+process_dict = {
+}
+
+
+class MergeJob(BaseJob):
+
+ def __init__(self, config: OrderedDict):
+ super().__init__(config)
+ self.dtype = self.get_conf('dtype', 'fp16')
+ self.torch_dtype = get_torch_dtype(self.dtype)
+ self.is_v2 = self.get_conf('is_v2', False)
+ self.device = self.get_conf('device', 'cpu')
+
+ # loads the processes from the config
+ self.load_processes(process_dict)
+
+ def run(self):
+ super().run()
+
+ print("")
+ print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+ for process in self.process:
+ process.run()
diff --git a/jobs/ModJob.py b/jobs/ModJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..e37990de95a0d2ad78a94f9cdfd6dfbda0cdc529
--- /dev/null
+++ b/jobs/ModJob.py
@@ -0,0 +1,28 @@
+import os
+from collections import OrderedDict
+from jobs import BaseJob
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.train_tools import get_torch_dtype
+
+process_dict = {
+ 'rescale_lora': 'ModRescaleLoraProcess',
+}
+
+
+class ModJob(BaseJob):
+
+ def __init__(self, config: OrderedDict):
+ super().__init__(config)
+ self.device = self.get_conf('device', 'cpu')
+
+ # loads the processes from the config
+ self.load_processes(process_dict)
+
+ def run(self):
+ super().run()
+
+ print("")
+ print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+ for process in self.process:
+ process.run()
diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py
new file mode 100644
index 0000000000000000000000000000000000000000..dda64e2d94171cf53dd02b279b0e0456dc013e09
--- /dev/null
+++ b/jobs/TrainJob.py
@@ -0,0 +1,49 @@
+import json
+import os
+
+from jobs import BaseJob
+from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
+from collections import OrderedDict
+from typing import List
+from jobs.process import BaseExtractProcess, TrainFineTuneProcess
+from datetime import datetime
+import yaml
+from toolkit.paths import REPOS_ROOT
+
+import sys
+
+sys.path.append(REPOS_ROOT)
+
+process_dict = {
+ 'vae': 'TrainVAEProcess',
+ 'slider': 'TrainSliderProcess',
+ 'slider_old': 'TrainSliderProcessOld',
+ 'lora_hack': 'TrainLoRAHack',
+ 'rescale_sd': 'TrainSDRescaleProcess',
+ 'esrgan': 'TrainESRGANProcess',
+ 'reference': 'TrainReferenceProcess',
+}
+
+
+class TrainJob(BaseJob):
+
+ def __init__(self, config: OrderedDict):
+ super().__init__(config)
+ self.training_folder = self.get_conf('training_folder', required=True)
+ self.is_v2 = self.get_conf('is_v2', False)
+ self.device = self.get_conf('device', 'cpu')
+ # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
+ # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
+ self.log_dir = self.get_conf('log_dir', None)
+
+ # loads the processes from the config
+ self.load_processes(process_dict)
+
+
+ def run(self):
+ super().run()
+ print("")
+ print(f"Running {len(self.process)} process{'' if len(self.process) == 1 else 'es'}")
+
+ for process in self.process:
+ process.run()
diff --git a/jobs/__init__.py b/jobs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da6c22b1ddd0ea9248e5afbf9b2ba014c137c1a
--- /dev/null
+++ b/jobs/__init__.py
@@ -0,0 +1,7 @@
+from .BaseJob import BaseJob
+from .ExtractJob import ExtractJob
+from .TrainJob import TrainJob
+from .MergeJob import MergeJob
+from .ModJob import ModJob
+from .GenerateJob import GenerateJob
+from .ExtensionJob import ExtensionJob
diff --git a/jobs/process/BaseExtensionProcess.py b/jobs/process/BaseExtensionProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b53dc1c498e64bb4adbc2b967b329fdc4a374925
--- /dev/null
+++ b/jobs/process/BaseExtensionProcess.py
@@ -0,0 +1,19 @@
+from collections import OrderedDict
+from typing import ForwardRef
+from jobs.process.BaseProcess import BaseProcess
+
+
+class BaseExtensionProcess(BaseProcess):
+ def __init__(
+ self,
+ process_id: int,
+ job,
+ config: OrderedDict
+ ):
+ super().__init__(process_id, job, config)
+ self.process_id: int
+ self.config: OrderedDict
+ self.progress_bar: ForwardRef('tqdm') = None
+
+ def run(self):
+ super().run()
diff --git a/jobs/process/BaseExtractProcess.py b/jobs/process/BaseExtractProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac10da54d82f15c8264b2799b10b01bb5cf8dc66
--- /dev/null
+++ b/jobs/process/BaseExtractProcess.py
@@ -0,0 +1,86 @@
+import os
+from collections import OrderedDict
+
+from safetensors.torch import save_file
+
+from jobs.process.BaseProcess import BaseProcess
+from toolkit.metadata import get_meta_for_safetensors
+
+from typing import ForwardRef
+
+from toolkit.train_tools import get_torch_dtype
+
+
+class BaseExtractProcess(BaseProcess):
+
+ def __init__(
+ self,
+ process_id: int,
+ job,
+ config: OrderedDict
+ ):
+ super().__init__(process_id, job, config)
+ self.config: OrderedDict
+ self.output_folder: str
+ self.output_filename: str
+ self.output_path: str
+ self.process_id = process_id
+ self.job = job
+ self.config = config
+ self.dtype = self.get_conf('dtype', self.job.dtype)
+ self.torch_dtype = get_torch_dtype(self.dtype)
+ self.extract_unet = self.get_conf('extract_unet', self.job.extract_unet)
+ self.extract_text_encoder = self.get_conf('extract_text_encoder', self.job.extract_text_encoder)
+
+ def run(self):
+ # here instead of init because child init needs to go first
+ self.output_path = self.get_output_path()
+ # implement in child class
+ # be sure to call super().run() first
+ pass
+
+ # you can override this in the child class if you want
+ # call super().get_output_path(prefix="your_prefix_", suffix="_your_suffix") to extend this
+ def get_output_path(self, prefix=None, suffix=None):
+ config_output_path = self.get_conf('output_path', None)
+ config_filename = self.get_conf('filename', None)
+ # replace [name] with name
+
+ if config_output_path is not None:
+ config_output_path = config_output_path.replace('[name]', self.job.name)
+ return config_output_path
+
+ if config_output_path is None and config_filename is not None:
+ # build the output path from the output folder and filename
+ return os.path.join(self.job.output_folder, config_filename)
+
+ # build our own
+
+ if suffix is None:
+ # we will just add the process id to the end of the filename if there is more than one process
+ # and no other suffix was given
+ suffix = f"_{self.process_id}" if len(self.config['process']) > 1 else ''
+
+ if prefix is None:
+ prefix = ''
+
+ output_filename = f"{prefix}{self.output_filename}{suffix}"
+
+ return os.path.join(self.job.output_folder, output_filename)
+
+ def save(self, state_dict):
+ # prepare meta
+ save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+
+ # save
+ os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
+
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(self.torch_dtype)
+ state_dict[key] = v
+
+ # having issues with meta
+ save_file(state_dict, self.output_path, save_meta)
+
+ print(f"Saved to {self.output_path}")
diff --git a/jobs/process/BaseMergeProcess.py b/jobs/process/BaseMergeProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..55dfec68ae62383afae539ff6cb51862033a7e10
--- /dev/null
+++ b/jobs/process/BaseMergeProcess.py
@@ -0,0 +1,46 @@
+import os
+from collections import OrderedDict
+
+from safetensors.torch import save_file
+
+from jobs.process.BaseProcess import BaseProcess
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.train_tools import get_torch_dtype
+
+
+class BaseMergeProcess(BaseProcess):
+
+ def __init__(
+ self,
+ process_id: int,
+ job,
+ config: OrderedDict
+ ):
+ super().__init__(process_id, job, config)
+ self.process_id: int
+ self.config: OrderedDict
+ self.output_path = self.get_conf('output_path', required=True)
+ self.dtype = self.get_conf('dtype', self.job.dtype)
+ self.torch_dtype = get_torch_dtype(self.dtype)
+
+ def run(self):
+ # implement in child class
+ # be sure to call super().run() first
+ pass
+
+ def save(self, state_dict):
+ # prepare meta
+ save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+
+ # save
+ os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
+
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(self.torch_dtype)
+ state_dict[key] = v
+
+ # having issues with meta
+ save_file(state_dict, self.output_path, save_meta)
+
+ print(f"Saved to {self.output_path}")
diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..f064460750e6bf845e3ecf4fa9ebf476d47eb162
--- /dev/null
+++ b/jobs/process/BaseProcess.py
@@ -0,0 +1,61 @@
+import copy
+import json
+from collections import OrderedDict
+
+from toolkit.timer import Timer
+
+
+class BaseProcess(object):
+
+ def __init__(
+ self,
+ process_id: int,
+ job: 'BaseJob',
+ config: OrderedDict
+ ):
+ self.process_id = process_id
+ self.meta: OrderedDict
+ self.job = job
+ self.config = config
+ self.raw_process_config = config
+ self.name = self.get_conf('name', self.job.name)
+ self.meta = copy.deepcopy(self.job.meta)
+ self.timer: Timer = Timer(f'{self.name} Timer')
+ self.performance_log_every = self.get_conf('performance_log_every', 0)
+
+ print(json.dumps(self.config, indent=4))
+
+ def get_conf(self, key, default=None, required=False, as_type=None):
+ # split key by '.' and recursively get the value
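+ # e.g. get_conf('save.dtype') walks self.config['save']['dtype']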
+ keys = key.split('.')
+
+ # see if it exists in the config
+ value = self.config
+ for subkey in keys:
+ if subkey in value:
+ value = value[subkey]
+ else:
+ value = None
+ break
+
+ if value is not None:
+ if as_type is not None:
+ value = as_type(value)
+ return value
+ elif required:
+ raise ValueError(f'config file error. Missing "config.process[{self.process_id}].{key}" key')
+ else:
+ if as_type is not None and default is not None:
+ return as_type(default)
+ return default
+
+ def run(self):
+ # implement in child class
+ # be sure to call super().run() first in case something is added here
+ pass
+
+ def add_meta(self, additional_meta: OrderedDict):
+ self.meta.update(additional_meta)
+
+
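+# imported at the bottom of the file, most likely to avoid a circular import between jobs and jobs.process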
+from jobs import BaseJob
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..87984717bd648f9b1b0d21c3c73b3e870f908ad8
--- /dev/null
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -0,0 +1,2105 @@
+import copy
+import glob
+import inspect
+import json
+import random
+import shutil
+from collections import OrderedDict
+import os
+import re
+from typing import Union, List, Optional
+
+import numpy as np
+import yaml
+from diffusers import T2IAdapter, ControlNetModel
+from diffusers.training_utils import compute_density_for_timestep_sampling
+from safetensors.torch import save_file, load_file
+# from lycoris.config import PRESET
+from torch.utils.data import DataLoader
+import torch
+import torch.backends.cuda
+from huggingface_hub import HfApi, Repository, interpreter_login
+from huggingface_hub.utils import HfFolder
+
+from toolkit.basic import value_map
+from toolkit.clip_vision_adapter import ClipVisionAdapter
+from toolkit.custom_adapter import CustomAdapter
+from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
+from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+from toolkit.ema import ExponentialMovingAverage
+from toolkit.embedding import Embedding
+from toolkit.image_utils import show_tensors, show_latents, reduce_contrast
+from toolkit.ip_adapter import IPAdapter
+from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
+ lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
+from toolkit.lycoris_special import LycorisSpecialNetwork
+from toolkit.models.decorator import Decorator
+from toolkit.network_mixins import Network
+from toolkit.optimizer import get_optimizer
+from toolkit.paths import CONFIG_ROOT
+from toolkit.progress_bar import ToolkitProgressBar
+from toolkit.reference_adapter import ReferenceAdapter
+from toolkit.sampler import get_sampler
+from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
+ load_ip_adapter_model, load_custom_adapter_model
+
+from toolkit.scheduler import get_lr_scheduler
+from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
+from toolkit.stable_diffusion_model import StableDiffusion
+
+from jobs.process import BaseTrainProcess
+from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
+ parse_metadata_from_safetensors
+from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma, apply_learnable_snr_gos, apply_snr_weight
+import gc
+
+from tqdm import tqdm
+
+from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
+ GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
+ DecoratorConfig
+from toolkit.logging import create_logger
+from diffusers import FluxTransformer2DModel
+
+def flush():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+class BaseSDTrainProcess(BaseTrainProcess):
+
+ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None):
+ super().__init__(process_id, job, config)
+ self.sd: StableDiffusion
+ self.embedding: Union[Embedding, None] = None
+
+ self.custom_pipeline = custom_pipeline
+ self.step_num = 0
+ self.start_step = 0
+ self.epoch_num = 0
+ # start at 1 so we can do a sample at the start
+ self.grad_accumulation_step = 1
+ # if true, then we do not do an optimizer step. We are accumulating gradients
+ self.is_grad_accumulation_step = False
+ self.device = self.get_conf('device', self.job.device)
+ self.device_torch = torch.device(self.device)
+ network_config = self.get_conf('network', None)
+ if network_config is not None:
+ self.network_config = NetworkConfig(**network_config)
+ else:
+ self.network_config = None
+ self.train_config = TrainConfig(**self.get_conf('train', {}))
+ model_config = self.get_conf('model', {})
+
+ # update the model config dtype to match the train config dtype
+ model_config['dtype'] = self.train_config.dtype
+ self.model_config = ModelConfig(**model_config)
+
+ self.save_config = SaveConfig(**self.get_conf('save', {}))
+ self.sample_config = SampleConfig(**self.get_conf('sample', {}))
+ first_sample_config = self.get_conf('first_sample', None)
+ if first_sample_config is not None:
+ self.has_first_sample_requested = True
+ self.first_sample_config = SampleConfig(**first_sample_config)
+ else:
+ self.has_first_sample_requested = False
+ self.first_sample_config = self.sample_config
+ self.logging_config = LoggingConfig(**self.get_conf('logging', {}))
+ self.logger = create_logger(self.logging_config, config)
+ self.optimizer: torch.optim.Optimizer = None
+ self.lr_scheduler = None
+ self.data_loader: Union[DataLoader, None] = None
+ self.data_loader_reg: Union[DataLoader, None] = None
+ self.trigger_word = self.get_conf('trigger_word', None)
+
+ self.guidance_config: Union[GuidanceConfig, None] = None
+ guidance_config_raw = self.get_conf('guidance', None)
+ if guidance_config_raw is not None:
+ self.guidance_config = GuidanceConfig(**guidance_config_raw)
+
+ # track whether all datasets are cached; this lets us skip loading the VAE when it isn't needed
+ self.is_latents_cached = True
+ raw_datasets = self.get_conf('datasets', None)
+ if raw_datasets is not None and len(raw_datasets) > 0:
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
+ self.datasets = None
+ self.datasets_reg = None
+ self.params = []
+ if raw_datasets is not None and len(raw_datasets) > 0:
+ for raw_dataset in raw_datasets:
+ dataset = DatasetConfig(**raw_dataset)
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
+ if not is_caching:
+ self.is_latents_cached = False
+ if dataset.is_reg:
+ if self.datasets_reg is None:
+ self.datasets_reg = []
+ self.datasets_reg.append(dataset)
+ else:
+ if self.datasets is None:
+ self.datasets = []
+ self.datasets.append(dataset)
+
+ self.embed_config = None
+ embedding_raw = self.get_conf('embedding', None)
+ if embedding_raw is not None:
+ self.embed_config = EmbeddingConfig(**embedding_raw)
+
+ self.decorator_config: DecoratorConfig = None
+ decorator_raw = self.get_conf('decorator', None)
+ if decorator_raw is not None:
+ if not self.model_config.is_flux:
+ raise ValueError("Decorators are only supported for Flux models currently")
+ self.decorator_config = DecoratorConfig(**decorator_raw)
+
+ # t2i adapter
+ self.adapter_config = None
+ adapter_raw = self.get_conf('adapter', None)
+ if adapter_raw is not None:
+ self.adapter_config = AdapterConfig(**adapter_raw)
+ # sdxl adapters end in _xl. Only full_adapter_xl for now
+ if self.model_config.is_xl and not self.adapter_config.adapter_type.endswith('_xl'):
+ self.adapter_config.adapter_type += '_xl'
+
+ # to hold network if there is one
+ self.network: Union[Network, None] = None
+ self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel, None] = None
+ self.embedding: Union[Embedding, None] = None
+ self.decorator: Union[Decorator, None] = None
+
+ is_training_adapter = self.adapter_config is not None and self.adapter_config.train
+
+ self.do_lorm = self.get_conf('do_lorm', False)
+ self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio')
+ self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25)
+ # 'ratio', 0.25)
+
+ # get the device state preset based on what we are training
+ self.train_device_state_preset = get_train_sd_device_state_preset(
+ device=self.device_torch,
+ train_unet=self.train_config.train_unet,
+ train_text_encoder=self.train_config.train_text_encoder,
+ cached_latents=self.is_latents_cached,
+ train_lora=self.network_config is not None,
+ train_adapter=is_training_adapter,
+ train_embedding=self.embed_config is not None,
+ train_decorator=self.decorator_config is not None,
+ train_refiner=self.train_config.train_refiner,
+ unload_text_encoder=self.train_config.unload_text_encoder,
+ require_grads=False # we ensure them later
+ )
+
+ self.get_params_device_state_preset = get_train_sd_device_state_preset(
+ device=self.device_torch,
+ train_unet=self.train_config.train_unet,
+ train_text_encoder=self.train_config.train_text_encoder,
+ cached_latents=self.is_latents_cached,
+ train_lora=self.network_config is not None,
+ train_adapter=is_training_adapter,
+ train_embedding=self.embed_config is not None,
+ train_decorator=self.decorator_config is not None,
+ train_refiner=self.train_config.train_refiner,
+ unload_text_encoder=self.train_config.unload_text_encoder,
+ require_grads=True # We check for grads when getting params
+ )
+
+ # fine_tuning here means training the actual SD network itself (Dreambooth, etc.), not LoRA, embeddings, or adapters
+ self.is_fine_tuning = True
+ if self.network_config is not None or is_training_adapter or self.embed_config is not None or self.decorator_config is not None:
+ self.is_fine_tuning = False
+
+ self.named_lora = False
+ if self.embed_config is not None or is_training_adapter:
+ self.named_lora = True
+ self.snr_gos: Union[LearnableSNRGamma, None] = None
+ self.ema: ExponentialMovingAverage = None
+
+ validate_configs(self.train_config, self.model_config, self.save_config)
+
+ def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
+ # override in subclass
+ return generate_image_config_list
+
+ def sample(self, step=None, is_first=False):
+ flush()
+ sample_folder = os.path.join(self.save_root, 'samples')
+ gen_img_config_list = []
+
+ sample_config = self.first_sample_config if is_first else self.sample_config
+ start_seed = sample_config.seed
+ current_seed = start_seed
+
+ test_image_paths = []
+ if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
+ test_image_path_list = self.adapter_config.test_img_path.split(',')
+ test_image_path_list = [p.strip() for p in test_image_path_list]
+ test_image_path_list = [p for p in test_image_path_list if p != '']
+ # divide up images so they are evenly distributed across prompts
+ for i in range(len(sample_config.prompts)):
+ test_image_paths.append(test_image_path_list[i % len(test_image_path_list)])
+
+ for i in range(len(sample_config.prompts)):
+ if sample_config.walk_seed:
+ current_seed = start_seed + i
+
+ step_num = ''
+ if step is not None:
+ # zero-pad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+
+ filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
+
+ output_path = os.path.join(sample_folder, filename)
+
+ prompt = sample_config.prompts[i]
+
+ # add embedding if there is one
+ # note: diffusers will automatically expand the trigger to the number of added tokens
+ # ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
+ if self.embedding is not None:
+ prompt = self.embedding.inject_embedding_to_prompt(
+ prompt, expand_token=True, add_if_not_present=False
+ )
+ if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
+ prompt = self.adapter.inject_trigger_into_prompt(
+ prompt, expand_token=True, add_if_not_present=False
+ )
+ if self.trigger_word is not None:
+ prompt = self.sd.inject_trigger_into_prompt(
+ prompt, self.trigger_word, add_if_not_present=False
+ )
+
+ extra_args = {}
+ if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
+ extra_args['adapter_image_path'] = test_image_paths[i]
+
+ gen_img_config_list.append(GenerateImageConfig(
+ prompt=prompt, # it will autoparse the prompt
+ width=sample_config.width,
+ height=sample_config.height,
+ negative_prompt=sample_config.neg,
+ seed=current_seed,
+ guidance_scale=sample_config.guidance_scale,
+ guidance_rescale=sample_config.guidance_rescale,
+ num_inference_steps=sample_config.sample_steps,
+ network_multiplier=sample_config.network_multiplier,
+ output_path=output_path,
+ output_ext=sample_config.ext,
+ adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
+ refiner_start_at=sample_config.refiner_start_at,
+ extra_values=sample_config.extra_values,
+ logger=self.logger,
+ **extra_args
+ ))
+
+ # post process
+ gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
+
+ # if we have an ema, set it to validation mode
+ if self.ema is not None:
+ self.ema.eval()
+
+ # send to be generated
+ self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
+
+ if self.ema is not None:
+ self.ema.train()
+
+ def update_training_metadata(self):
+ o_dict = OrderedDict({
+ "training_info": self.get_training_info()
+ })
+ if self.model_config.is_v2:
+ o_dict['ss_v2'] = True
+ o_dict['ss_base_model_version'] = 'sd_2.1'
+
+ elif self.model_config.is_xl:
+ o_dict['ss_base_model_version'] = 'sdxl_1.0'
+ else:
+ o_dict['ss_base_model_version'] = 'sd_1.5'
+
+ o_dict = add_base_model_info_to_meta(
+ o_dict,
+ is_v2=self.model_config.is_v2,
+ is_xl=self.model_config.is_xl,
+ )
+ o_dict['ss_output_name'] = self.job.name
+
+ if self.trigger_word is not None:
+ # just so auto1111 will pick it up
+ o_dict['ss_tag_frequency'] = {
+ f"1_{self.trigger_word}": {
+ f"{self.trigger_word}": 1
+ }
+ }
+
+ self.add_meta(o_dict)
+
+ def get_training_info(self):
+ info = OrderedDict({
+ 'step': self.step_num,
+ 'epoch': self.epoch_num,
+ })
+ return info
+
+ def clean_up_saves(self):
+ # remove old saves
+ # get latest saved step
+ latest_item = None
+ if os.path.exists(self.save_root):
+ # pattern is {job_name}_{zero_filled_step} for both files and directories
+ pattern = f"{self.job.name}_*"
+ items = glob.glob(os.path.join(self.save_root, pattern))
+ # Separate files and directories
+ safetensors_files = [f for f in items if f.endswith('.safetensors')]
+ pt_files = [f for f in items if f.endswith('.pt')]
+ directories = [d for d in items if os.path.isdir(d) and not d.endswith('.safetensors')]
+ embed_files = []
+ # do embedding files
+ if self.embed_config is not None:
+ embed_pattern = f"{self.embed_config.trigger}_*"
+ embed_items = glob.glob(os.path.join(self.save_root, embed_pattern))
+ # will end in safetensors or pt
+ embed_files = [f for f in embed_items if f.endswith('.safetensors') or f.endswith('.pt')]
+
+ # check for critic files
+ critic_pattern = f"CRITIC_{self.job.name}_*"
+ critic_items = glob.glob(os.path.join(self.save_root, critic_pattern))
+
+ # Sort the lists by creation time if they are not empty
+ if safetensors_files:
+ safetensors_files.sort(key=os.path.getctime)
+ if pt_files:
+ pt_files.sort(key=os.path.getctime)
+ if directories:
+ directories.sort(key=os.path.getctime)
+ if embed_files:
+ embed_files.sort(key=os.path.getctime)
+ if critic_items:
+ critic_items.sort(key=os.path.getctime)
+
+ # Combine and sort the lists
+ combined_items = safetensors_files + directories + pt_files
+ combined_items.sort(key=os.path.getctime)
+
+ # Use slicing with a check to avoid 'NoneType' error
+ safetensors_to_remove = safetensors_files[
+ :-self.save_config.max_step_saves_to_keep] if safetensors_files else []
+ pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
+ directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
+ embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
+ critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
+
+ items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
+
+ # remove all but the latest max_step_saves_to_keep
+ # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
+
+ # remove duplicates
+ items_to_remove = list(dict.fromkeys(items_to_remove))
+
+ for item in items_to_remove:
+ self.print(f"Removing old save: {item}")
+ if os.path.isdir(item):
+ shutil.rmtree(item)
+ else:
+ os.remove(item)
+ # see if a yaml file with same name exists
+ yaml_file = os.path.splitext(item)[0] + ".yaml"
+ if os.path.exists(yaml_file):
+ os.remove(yaml_file)
+ if combined_items:
+ latest_item = combined_items[-1]
+ return latest_item
+
+ def post_save_hook(self, save_path):
+ # override in subclass
+ pass
+
+ def save(self, step=None):
+ flush()
+ if self.ema is not None:
+ # always save params as ema
+ self.ema.eval()
+
+ if not os.path.exists(self.save_root):
+ os.makedirs(self.save_root, exist_ok=True)
+
+ step_num = ''
+ if step is not None:
+ # zero-pad to 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+
+ self.update_training_metadata()
+ filename = f'{self.job.name}{step_num}.safetensors'
+ file_path = os.path.join(self.save_root, filename)
+
+ save_meta = copy.deepcopy(self.meta)
+ # get extra meta
+ if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+ additional_save_meta = self.adapter.get_additional_save_metadata()
+ if additional_save_meta is not None:
+ for key, value in additional_save_meta.items():
+ save_meta[key] = value
+
+ # prepare meta
+ save_meta = get_meta_for_safetensors(save_meta, self.job.name)
+ if not self.is_fine_tuning:
+ if self.network is not None:
+ lora_name = self.job.name
+ if self.named_lora:
+ # add _lora to name
+ lora_name += '_LoRA'
+
+ filename = f'{lora_name}{step_num}.safetensors'
+ file_path = os.path.join(self.save_root, filename)
+ prev_multiplier = self.network.multiplier
+ self.network.multiplier = 1.0
+
+ # if we are doing embedding training as well, add that
+ embedding_dict = self.embedding.state_dict() if self.embedding else None
+ self.network.save_weights(
+ file_path,
+ dtype=get_torch_dtype(self.save_config.dtype),
+ metadata=save_meta,
+ extra_state_dict=embedding_dict
+ )
+ self.network.multiplier = prev_multiplier
+ # if we have an embedding as well, pair it with the network
+
+ # even if added to lora, still save the trigger version
+ if self.embedding is not None:
+ emb_filename = f'{self.embed_config.trigger}{step_num}.safetensors'
+ emb_file_path = os.path.join(self.save_root, emb_filename)
+ # for combo, above will get it
+ # set current step
+ self.embedding.step = self.step_num
+ # change filename to pt if that is set
+ if self.embed_config.save_format == "pt":
+ # replace extension
+ emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
+ self.embedding.save(emb_file_path)
+
+ if self.decorator is not None:
+ dec_filename = f'{self.job.name}{step_num}.safetensors'
+ dec_file_path = os.path.join(self.save_root, dec_filename)
+ decorator_state_dict = self.decorator.state_dict()
+ for key, value in decorator_state_dict.items():
+ if isinstance(value, torch.Tensor):
+ decorator_state_dict[key] = value.clone().to('cpu', dtype=get_torch_dtype(self.save_config.dtype))
+ save_file(
+ decorator_state_dict,
+ dec_file_path,
+ metadata=save_meta,
+ )
+
+ if self.adapter is not None and self.adapter_config.train:
+ adapter_name = self.job.name
+                if self.network_config is not None or self.embedding is not None:
+                    # append an adapter-type suffix so the filename doesn't clash with the LoRA / embedding saves
+ if self.adapter_config.type == 't2i':
+ adapter_name += '_t2i'
+ elif self.adapter_config.type == 'control_net':
+ adapter_name += '_cn'
+ elif self.adapter_config.type == 'clip':
+ adapter_name += '_clip'
+ elif self.adapter_config.type.startswith('ip'):
+ adapter_name += '_ip'
+ else:
+ adapter_name += '_adapter'
+
+ filename = f'{adapter_name}{step_num}.safetensors'
+ file_path = os.path.join(self.save_root, filename)
+ # save adapter
+ state_dict = self.adapter.state_dict()
+ if self.adapter_config.type == 't2i':
+ save_t2i_from_diffusers(
+ state_dict,
+ output_file=file_path,
+ meta=save_meta,
+ dtype=get_torch_dtype(self.save_config.dtype)
+ )
+ elif self.adapter_config.type == 'control_net':
+ # save in diffusers format
+ name_or_path = file_path.replace('.safetensors', '')
+ # move it to the new dtype and cpu
+ orig_device = self.adapter.device
+ orig_dtype = self.adapter.dtype
+ self.adapter = self.adapter.to(torch.device('cpu'), dtype=get_torch_dtype(self.save_config.dtype))
+ self.adapter.save_pretrained(
+ name_or_path,
+ dtype=get_torch_dtype(self.save_config.dtype),
+ safe_serialization=True
+ )
+ meta_path = os.path.join(name_or_path, 'aitk_meta.yaml')
+ with open(meta_path, 'w') as f:
+ yaml.dump(self.meta, f)
+ # move it back
+ self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
+ else:
+ direct_save = False
+ if self.adapter_config.train_only_image_encoder:
+ direct_save = True
+ if self.adapter_config.type == 'redux':
+ direct_save = True
+ save_ip_adapter_from_diffusers(
+ state_dict,
+ output_file=file_path,
+ meta=save_meta,
+ dtype=get_torch_dtype(self.save_config.dtype),
+ direct_save=direct_save
+ )
+ else:
+ if self.save_config.save_format == "diffusers":
+ # saving as a folder path
+ file_path = file_path.replace('.safetensors', '')
+ # convert it back to normal object
+ save_meta = parse_metadata_from_safetensors(save_meta)
+
+ if self.sd.refiner_unet and self.train_config.train_refiner:
+ # save refiner
+ refiner_name = self.job.name + '_refiner'
+ filename = f'{refiner_name}{step_num}.safetensors'
+ file_path = os.path.join(self.save_root, filename)
+ self.sd.save_refiner(
+ file_path,
+ save_meta,
+ get_torch_dtype(self.save_config.dtype)
+ )
+ if self.train_config.train_unet or self.train_config.train_text_encoder:
+ self.sd.save(
+ file_path,
+ save_meta,
+ get_torch_dtype(self.save_config.dtype)
+ )
+
+        # save learnable params as json if we have them
+ if self.snr_gos:
+ json_data = {
+ 'offset_1': self.snr_gos.offset_1.item(),
+ 'offset_2': self.snr_gos.offset_2.item(),
+ 'scale': self.snr_gos.scale.item(),
+ 'gamma': self.snr_gos.gamma.item(),
+ }
+            path_to_save = os.path.join(self.save_root, 'learnable_snr.json')
+ with open(path_to_save, 'w') as f:
+ json.dump(json_data, f, indent=4)
+
+ # save optimizer
+ if self.optimizer is not None:
+ try:
+ filename = f'optimizer.pt'
+ file_path = os.path.join(self.save_root, filename)
+ state_dict = self.optimizer.state_dict()
+ torch.save(state_dict, file_path)
+ except Exception as e:
+ print(e)
+ print("Could not save optimizer")
+
+ self.print(f"Saved to {file_path}")
+ self.clean_up_saves()
+ self.post_save_hook(file_path)
+
+ if self.ema is not None:
+ self.ema.train()
+ flush()
+
+ # Called before the model is loaded
+ def hook_before_model_load(self):
+ # override in subclass
+ pass
+
+ def hook_after_model_load(self):
+ # override in subclass
+ pass
+
+ def hook_add_extra_train_params(self, params):
+ # override in subclass
+ return params
+
+ def hook_before_train_loop(self):
+ self.logger.start()
+
+ def ensure_params_requires_grad(self, force=False):
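+        # Re-enable gradients on every trainable parameter. When parameter swapping is
+        # active the optimizer manages requires_grad itself unless force=True.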
+ if self.train_config.do_paramiter_swapping and not force:
+ # the optimizer will handle this if we are not forcing
+ return
+ for group in self.params:
+ for param in group['params']:
+ if isinstance(param, torch.nn.Parameter): # Ensure it's a proper parameter
+ param.requires_grad_(True)
+
+ def setup_ema(self):
+ if self.train_config.ema_config.use_ema:
+ # our params are in groups. We need them as a single iterable
+ params = []
+ for group in self.optimizer.param_groups:
+ for param in group['params']:
+ params.append(param)
+ self.ema = ExponentialMovingAverage(
+ params,
+ decay=self.train_config.ema_config.ema_decay,
+ use_feedback=self.train_config.ema_config.use_feedback,
+ param_multiplier=self.train_config.ema_config.param_multiplier,
+ )
+
+ def before_dataset_load(self):
+ pass
+
+ def get_params(self):
+ # you can extend this in subclass to get params
+ # otherwise params will be gathered through normal means
+ return None
+
+ def hook_train_loop(self, batch):
+ # return loss
+ return 0.0
+
+ def get_latest_save_path(self, name=None, post=''):
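+        # Return the most recently created checkpoint (file or diffusers folder) for this
+        # job, filtering out sibling artifacts such as _LoRA, _refiner, _t2i and _cn saves.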
+        if name is None:
+ name = self.job.name
+ # get latest saved step
+ latest_path = None
+ if os.path.exists(self.save_root):
+ # Define patterns for both files and directories
+ patterns = [
+ f"{name}*{post}.safetensors",
+ f"{name}*{post}.pt",
+ f"{name}*{post}"
+ ]
+ # Search for both files and directories
+ paths = []
+ for pattern in patterns:
+ paths.extend(glob.glob(os.path.join(self.save_root, pattern)))
+
+ # Filter out non-existent paths and sort by creation time
+ if paths:
+ paths = [p for p in paths if os.path.exists(p)]
+ # remove false positives
+ if '_LoRA' not in name:
+ paths = [p for p in paths if '_LoRA' not in p]
+ if '_refiner' not in name:
+ paths = [p for p in paths if '_refiner' not in p]
+ if '_t2i' not in name:
+ paths = [p for p in paths if '_t2i' not in p]
+ if '_cn' not in name:
+ paths = [p for p in paths if '_cn' not in p]
+
+ if len(paths) > 0:
+ latest_path = max(paths, key=os.path.getctime)
+
+ return latest_path
+
+ def load_training_state_from_metadata(self, path):
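+        # Restore step / epoch counters from a checkpoint's metadata: aitk_meta.yaml for
+        # diffusers folders, embedded safetensors metadata otherwise.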
+ meta = None
+ # if path is folder, then it is diffusers
+ if os.path.isdir(path):
+ meta_path = os.path.join(path, 'aitk_meta.yaml')
+ # load it
+ if os.path.exists(meta_path):
+ with open(meta_path, 'r') as f:
+ meta = yaml.load(f, Loader=yaml.FullLoader)
+ else:
+ meta = load_metadata_from_safetensors(path)
+        # if 'training_info' is in the OrderedDict keys
+ if meta is not None and 'training_info' in meta and 'step' in meta['training_info'] and self.train_config.start_step is None:
+ self.step_num = meta['training_info']['step']
+ if 'epoch' in meta['training_info']:
+ self.epoch_num = meta['training_info']['epoch']
+ self.start_step = self.step_num
+ print(f"Found step {self.step_num} in metadata, starting from there")
+
+ def load_weights(self, path):
+ if self.network is not None:
+ extra_weights = self.network.load_weights(path)
+ self.load_training_state_from_metadata(path)
+ return extra_weights
+ else:
+ print("load_weights not implemented for non-network models")
+ return None
+
+ def apply_snr(self, seperated_loss, timesteps):
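+        # Re-weight the per-sample loss by signal-to-noise ratio: learned parameters when
+        # enabled, otherwise the configured snr_gamma / min_snr_gamma weighting.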
+ if self.train_config.learnable_snr_gos:
+ # add snr_gamma
+ seperated_loss = apply_learnable_snr_gos(seperated_loss, timesteps, self.snr_gos)
+ elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
+ # add snr_gamma
+ seperated_loss = apply_snr_weight(seperated_loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
+ elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+ # add min_snr_gamma
+ seperated_loss = apply_snr_weight(seperated_loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
+
+ return seperated_loss
+
+ def load_lorm(self):
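+        # Reload the latest LoRM checkpoint directly into the UNet and restore the step
+        # counter from its metadata.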
+ latest_save_path = self.get_latest_save_path()
+ if latest_save_path is not None:
+ # hacky way to reload weights for now
+ # todo, do this
+ state_dict = load_file(latest_save_path, device=self.device)
+ self.sd.unet.load_state_dict(state_dict)
+
+ meta = load_metadata_from_safetensors(latest_save_path)
+            # if 'training_info' is in the OrderedDict keys
+ if 'training_info' in meta and 'step' in meta['training_info']:
+ self.step_num = meta['training_info']['step']
+ if 'epoch' in meta['training_info']:
+ self.epoch_num = meta['training_info']['epoch']
+ self.start_step = self.step_num
+ print(f"Found step {self.step_num} in metadata, starting from there")
+
+ # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
+ # self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
+ # sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
+ # schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch, )
+ # timesteps = timesteps.to(self.device_torch, )
+ #
+ # # step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+ # step_indices = [t for t in timesteps]
+ #
+ # sigma = sigmas[step_indices].flatten()
+ # while len(sigma.shape) < n_dim:
+ # sigma = sigma.unsqueeze(-1)
+ # return sigma
+
+ def load_additional_training_modules(self, params):
+ # override in subclass
+ return params
+
+ def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
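+        # Look up the scheduler sigma for each sampled timestep and reshape it so it
+        # broadcasts against the latent tensor.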
+ sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device, dtype=dtype)
+ schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device)
+ timesteps = timesteps.to(self.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ def get_noise(self, latents, batch_size, dtype=torch.float32):
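+        # Sample latent-shaped gaussian noise, optionally applying a random per-channel
+        # shift and re-standardizing it afterwards.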
+ # get noise
+ noise = self.sd.get_latent_noise(
+ height=latents.shape[2],
+ width=latents.shape[3],
+ batch_size=batch_size,
+ noise_offset=self.train_config.noise_offset,
+ ).to(self.device_torch, dtype=dtype)
+
+ if self.train_config.random_noise_shift > 0.0:
+ # get random noise -1 to 1
+ noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device,
+ dtype=noise.dtype) * 2 - 1
+
+ # multiply by shift amount
+ noise_shift *= self.train_config.random_noise_shift
+
+ # add to noise
+ noise += noise_shift
+
+ # standardize the noise
+ std = noise.std(dim=(2, 3), keepdim=True)
+ normalizer = 1 / (std + 1e-6)
+ noise = noise * normalizer
+
+ return noise
+
+ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
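+        # Build the conditioned prompts, encode (and optionally standardize) the latents,
+        # sample timesteps and noise, and return
+        # (noisy_latents, noise, timesteps, conditioned_prompts, imgs).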
+ with torch.no_grad():
+ with self.timer('prepare_prompt'):
+ prompts = batch.get_caption_list()
+ is_reg_list = batch.get_is_reg_list()
+
+                is_any_reg = any(is_reg_list)
+
+ do_double = self.train_config.short_and_long_captions and not is_any_reg
+
+                if do_double:
+                    # don't do this with regs. No point
+
+ # double batch and add short captions to the end
+ prompts = prompts + batch.get_caption_short_list()
+ is_reg_list = is_reg_list + is_reg_list
+ if self.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
+ prompts = prompts + prompts
+ is_reg_list = is_reg_list + is_reg_list
+
+ conditioned_prompts = []
+
+ for prompt, is_reg in zip(prompts, is_reg_list):
+
+ # make sure the embedding is in the prompts
+ if self.embedding is not None:
+ prompt = self.embedding.inject_embedding_to_prompt(
+ prompt,
+ expand_token=True,
+ add_if_not_present=not is_reg,
+ )
+
+ if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
+ prompt = self.adapter.inject_trigger_into_prompt(
+ prompt,
+ expand_token=True,
+ add_if_not_present=not is_reg,
+ )
+
+ # make sure trigger is in the prompts if not a regularization run
+ if self.trigger_word is not None:
+ prompt = self.sd.inject_trigger_into_prompt(
+ prompt,
+ trigger=self.trigger_word,
+ add_if_not_present=not is_reg,
+ )
+
+ if not is_reg and self.train_config.prompt_saturation_chance > 0.0:
+ # do random prompt saturation by expanding the prompt to hit at least 77 tokens
+ if random.random() < self.train_config.prompt_saturation_chance:
+ est_num_tokens = len(prompt.split(' '))
+ if est_num_tokens < 77:
+ num_repeats = int(77 / est_num_tokens) + 1
+ prompt = ', '.join([prompt] * num_repeats)
+
+
+ conditioned_prompts.append(prompt)
+
+ with self.timer('prepare_latents'):
+ dtype = get_torch_dtype(self.train_config.dtype)
+ imgs = None
+ is_reg = any(batch.get_is_reg_list())
+ if batch.tensor is not None:
+ imgs = batch.tensor
+ imgs = imgs.to(self.device_torch, dtype=dtype)
+                    # don't adjust for regs.
+                    if self.train_config.img_multiplier is not None and not is_reg:
+                        # apply the multiplier as a contrast reduction
+ imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
+ if batch.latents is not None:
+ latents = batch.latents.to(self.device_torch, dtype=dtype)
+ batch.latents = latents
+ else:
+                    # no cached latents; normalize to target channel statistics if configured, then encode
+ if self.train_config.standardize_images:
+ if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
+ target_mean_list = [0.0002, -0.1034, -0.1879]
+ target_std_list = [0.5436, 0.5116, 0.5033]
+ else:
+ target_mean_list = [-0.0739, -0.1597, -0.2380]
+ target_std_list = [0.5623, 0.5295, 0.5347]
+ # Mean: tensor([-0.0739, -0.1597, -0.2380])
+ # Standard Deviation: tensor([0.5623, 0.5295, 0.5347])
+ imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True)
+ imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True)
+ imgs = (imgs - imgs_channel_mean) / imgs_channel_std
+ target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
+ target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
+ # expand them to match dim
+ target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+
+ imgs = imgs * target_std + target_mean
+ batch.tensor = imgs
+
+ # show_tensors(imgs, 'imgs')
+
+ latents = self.sd.encode_images(imgs)
+ batch.latents = latents
+
+ if self.train_config.standardize_latents:
+ if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
+ target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164]
+ target_std_list = [0.8979, 0.7505, 0.9150, 0.7451]
+ else:
+ target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929]
+ target_std_list = [0.8560, 0.9629, 0.7778, 0.6719]
+
+ latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True)
+ latents_channel_std = latents.std(dim=(2, 3), keepdim=True)
+ latents = (latents - latents_channel_mean) / latents_channel_std
+ target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
+ target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
+ # expand them to match dim
+ target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+
+ latents = latents * target_std + target_mean
+ batch.latents = latents
+
+ # show_latents(latents, self.sd.vae, 'latents')
+
+
+ if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
+ unconditional_imgs = batch.unconditional_tensor
+ unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
+ unconditional_latents = self.sd.encode_images(unconditional_imgs)
+ batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier
+
+ unaugmented_latents = None
+ if self.train_config.loss_target == 'differential_noise':
+ # we determine noise from the differential of the latents
+ unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
+
+ batch_size = len(batch.file_items)
+ min_noise_steps = self.train_config.min_denoising_steps
+ max_noise_steps = self.train_config.max_denoising_steps
+ if self.model_config.refiner_name_or_path is not None:
+ # if we are not training the unet, then we are only doing refiner and do not need to double up
+ if self.train_config.train_unet:
+ max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
+ do_double = True
+ else:
+ min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at)
+ do_double = False
+
+ with self.timer('prepare_noise'):
+ num_train_timesteps = self.train_config.num_train_timesteps
+
+ if self.train_config.noise_scheduler in ['custom_lcm']:
+ # we store this value on our custom one
+ self.sd.noise_scheduler.set_timesteps(
+ self.sd.noise_scheduler.train_timesteps, device=self.device_torch
+ )
+ elif self.train_config.noise_scheduler in ['lcm']:
+ self.sd.noise_scheduler.set_timesteps(
+ num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
+ )
+ elif self.train_config.noise_scheduler == 'flowmatch':
+ linear_timesteps = any([
+ self.train_config.linear_timesteps,
+ self.train_config.linear_timesteps2,
+ self.train_config.timestep_type == 'linear',
+ ])
+ self.sd.noise_scheduler.set_train_timesteps(
+ num_train_timesteps,
+ device=self.device_torch,
+ linear=linear_timesteps
+ )
+ else:
+ self.sd.noise_scheduler.set_timesteps(
+ num_train_timesteps, device=self.device_torch
+ )
+
+ content_or_style = self.train_config.content_or_style
+ if is_reg:
+ content_or_style = self.train_config.content_or_style_reg
+
+ # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
+ if content_or_style in ['style', 'content']:
+ # this is from diffusers training code
+ # Cubic sampling for favoring later or earlier timesteps
+ # For more details about why cubic sampling is used for content / structure,
+ # refer to section 3.4 of https://arxiv.org/abs/2302.08453
+
+ # for content / structure, it is best to favor earlier timesteps
+ # for style, it is best to favor later timesteps
+
+ orig_timesteps = torch.rand((batch_size,), device=latents.device)
+
+ if content_or_style == 'content':
+ timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
+ elif content_or_style == 'style':
+ timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps
+
+ timestep_indices = value_map(
+ timestep_indices,
+ 0,
+ self.train_config.num_train_timesteps - 1,
+ min_noise_steps,
+ max_noise_steps - 1
+ )
+ timestep_indices = timestep_indices.long().clamp(
+ min_noise_steps + 1,
+ max_noise_steps - 1
+ )
+
+ elif content_or_style == 'balanced':
+ if min_noise_steps == max_noise_steps:
+ timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
+ else:
+                        # todo, some schedulers use indices, others use timesteps. Not sure what to do here
+ timestep_indices = torch.randint(
+ min_noise_steps + 1,
+ max_noise_steps - 1,
+ (batch_size,),
+ device=self.device_torch
+ )
+ timestep_indices = timestep_indices.long()
+ else:
+ raise ValueError(f"Unknown content_or_style {content_or_style}")
+
+ # do flow matching
+ # if self.sd.is_flow_matching:
+ # u = compute_density_for_timestep_sampling(
+ # weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+ # batch_size=batch_size,
+ # logit_mean=0.0,
+ # logit_std=1.0,
+ # mode_scale=1.29,
+ # )
+ # timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
+ # convert the timestep_indices to a timestep
+ timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
+ timesteps = torch.stack(timesteps, dim=0)
+
+ # get noise
+ noise = self.get_noise(latents, batch_size, dtype=dtype)
+
+ # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
+ # this will negate any noise offsets
+ if self.train_config.dynamic_noise_offset and not is_reg:
+ latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
+                    # add the (halved) channel mean so the noise offset compensates for the per-channel mean of the latents
+ noise = noise + latents_channel_mean
+
+ if self.train_config.loss_target == 'differential_noise':
+ differential = latents - unaugmented_latents
+ # add noise to differential
+ # noise = noise + differential
+ noise = noise + (differential * 0.5)
+ # noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max())
+ latents = unaugmented_latents
+
+ noise_multiplier = self.train_config.noise_multiplier
+
+ noise = noise * noise_multiplier
+
+ latent_multiplier = self.train_config.latent_multiplier
+
+                # handle adaptive scaling based on std
+ if self.train_config.adaptive_scaling_factor:
+ std = latents.std(dim=(2, 3), keepdim=True)
+ normalizer = 1 / (std + 1e-6)
+ latent_multiplier = normalizer
+
+ latents = latents * latent_multiplier
+ batch.latents = latents
+
+ # normalize latents to a mean of 0 and an std of 1
+ # mean_zero_latents = latents - latents.mean()
+ # latents = mean_zero_latents / mean_zero_latents.std()
+
+ if batch.unconditional_latents is not None:
+ batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier
+
+
+ noisy_latents = self.sd.add_noise(latents, noise, timesteps)
+
+ # determine scaled noise
+ # todo do we need to scale this or does it always predict full intensity
+ # noise = noisy_latents - latents
+
+ # https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
+ if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
+ sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
+ # add it to the batch
+ batch.sigmas = sigmas
+ # todo is this for sdxl? find out where this came from originally
+ # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
+
+ def double_up_tensor(tensor: torch.Tensor):
+ if tensor is None:
+ return None
+ return torch.cat([tensor, tensor], dim=0)
+
+ if do_double:
+ if self.model_config.refiner_name_or_path:
+ # apply refiner double up
+ refiner_timesteps = torch.randint(
+ max_noise_steps,
+ self.train_config.max_denoising_steps,
+ (batch_size,),
+ device=self.device_torch
+ )
+ refiner_timesteps = refiner_timesteps.long()
+ # add our new timesteps on to end
+ timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)
+
+ refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
+ noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
+
+ else:
+ # just double it
+ noisy_latents = double_up_tensor(noisy_latents)
+ timesteps = double_up_tensor(timesteps)
+
+ noise = double_up_tensor(noise)
+ # prompts are already updated above
+ imgs = double_up_tensor(imgs)
+ batch.mask_tensor = double_up_tensor(batch.mask_tensor)
+ batch.control_tensor = double_up_tensor(batch.control_tensor)
+
+ noisy_latent_multiplier = self.train_config.noisy_latent_multiplier
+
+ if noisy_latent_multiplier != 1.0:
+ noisy_latents = noisy_latents * noisy_latent_multiplier
+
+ # remove grads for these
+ noisy_latents.requires_grad = False
+ noisy_latents = noisy_latents.detach()
+ noise.requires_grad = False
+ noise = noise.detach()
+
+ return noisy_latents, noise, timesteps, conditioned_prompts, imgs
+
+ def setup_adapter(self):
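+        # Instantiate the configured adapter (T2I, ControlNet, CLIP vision, reference, IP
+        # or custom), resume from its latest checkpoint if one exists, and attach it to self.sd.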
+ # t2i adapter
+ is_t2i = self.adapter_config.type == 't2i'
+ is_control_net = self.adapter_config.type == 'control_net'
+ if self.adapter_config.type == 't2i':
+ suffix = 't2i'
+ elif self.adapter_config.type == 'control_net':
+ suffix = 'cn'
+ elif self.adapter_config.type == 'clip':
+ suffix = 'clip'
+ elif self.adapter_config.type == 'reference':
+ suffix = 'ref'
+ elif self.adapter_config.type.startswith('ip'):
+ suffix = 'ip'
+ else:
+ suffix = 'adapter'
+ adapter_name = self.name
+ if self.network_config is not None:
+ adapter_name = f"{adapter_name}_{suffix}"
+ latest_save_path = self.get_latest_save_path(adapter_name)
+
+ dtype = get_torch_dtype(self.train_config.dtype)
+ if is_t2i:
+ # if we do not have a last save path and we have a name_or_path,
+ # load from that
+ if latest_save_path is None and self.adapter_config.name_or_path is not None:
+ self.adapter = T2IAdapter.from_pretrained(
+ self.adapter_config.name_or_path,
+ torch_dtype=get_torch_dtype(self.train_config.dtype),
+                    variant="fp16",
+ # use_safetensors=True,
+ )
+ else:
+ self.adapter = T2IAdapter(
+ in_channels=self.adapter_config.in_channels,
+ channels=self.adapter_config.channels,
+ num_res_blocks=self.adapter_config.num_res_blocks,
+ downscale_factor=self.adapter_config.downscale_factor,
+ adapter_type=self.adapter_config.adapter_type,
+ )
+ elif is_control_net:
+ if self.adapter_config.name_or_path is None:
+ raise ValueError("ControlNet requires a name_or_path to load from currently")
+ load_from_path = self.adapter_config.name_or_path
+ if latest_save_path is not None:
+ load_from_path = latest_save_path
+ self.adapter = ControlNetModel.from_pretrained(
+ load_from_path,
+ torch_dtype=get_torch_dtype(self.train_config.dtype),
+ )
+ elif self.adapter_config.type == 'clip':
+ self.adapter = ClipVisionAdapter(
+ sd=self.sd,
+ adapter_config=self.adapter_config,
+ )
+ elif self.adapter_config.type == 'reference':
+ self.adapter = ReferenceAdapter(
+ sd=self.sd,
+ adapter_config=self.adapter_config,
+ )
+ elif self.adapter_config.type.startswith('ip'):
+ self.adapter = IPAdapter(
+ sd=self.sd,
+ adapter_config=self.adapter_config,
+ )
+ if self.train_config.gradient_checkpointing:
+ self.adapter.enable_gradient_checkpointing()
+ else:
+ self.adapter = CustomAdapter(
+ sd=self.sd,
+ adapter_config=self.adapter_config,
+ )
+ self.adapter.to(self.device_torch, dtype=dtype)
+ if latest_save_path is not None and not is_control_net:
+ # load adapter from path
+ print(f"Loading adapter from {latest_save_path}")
+ if is_t2i:
+ loaded_state_dict = load_t2i_model(
+ latest_save_path,
+ self.device,
+ dtype=dtype
+ )
+ self.adapter.load_state_dict(loaded_state_dict)
+ elif self.adapter_config.type.startswith('ip'):
+ # ip adapter
+ loaded_state_dict = load_ip_adapter_model(
+ latest_save_path,
+ self.device,
+ dtype=dtype,
+ direct_load=self.adapter_config.train_only_image_encoder
+ )
+ self.adapter.load_state_dict(loaded_state_dict)
+ else:
+ # custom adapter
+ loaded_state_dict = load_custom_adapter_model(
+ latest_save_path,
+ self.device,
+ dtype=dtype
+ )
+ self.adapter.load_state_dict(loaded_state_dict)
+ if latest_save_path is not None and self.adapter_config.train:
+ self.load_training_state_from_metadata(latest_save_path)
+ # set trainable params
+ self.sd.adapter = self.adapter
+
+ def run(self):
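+        # Main entry point: load the model, set up network / embedding / adapter params,
+        # build the optimizer and LR scheduler, then run the train loop with periodic
+        # sampling, saving and logging.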
+ # torch.autograd.set_detect_anomaly(True)
+ # run base process run
+ BaseTrainProcess.run(self)
+ params = []
+
+ ### HOOK ###
+ self.hook_before_model_load()
+ model_config_to_load = copy.deepcopy(self.model_config)
+
+ if self.is_fine_tuning:
+ # get the latest checkpoint
+ # check to see if we have a latest save
+ latest_save_path = self.get_latest_save_path()
+
+ if latest_save_path is not None:
+ print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+ model_config_to_load.name_or_path = latest_save_path
+ self.load_training_state_from_metadata(latest_save_path)
+
+ # get the noise scheduler
+ sampler = get_sampler(
+ self.train_config.noise_scheduler,
+ {
+ "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
+ },
+ 'sd' if not self.model_config.is_pixart else 'pixart'
+ )
+
+ if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
+ previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
+ if previous_refiner_save is not None:
+ model_config_to_load.refiner_name_or_path = previous_refiner_save
+ self.load_training_state_from_metadata(previous_refiner_save)
+
+ self.sd = StableDiffusion(
+ device=self.device,
+ model_config=model_config_to_load,
+ dtype=self.train_config.dtype,
+ custom_pipeline=self.custom_pipeline,
+ noise_scheduler=sampler,
+ )
+ # run base sd process run
+ self.sd.load_model()
+
+ dtype = get_torch_dtype(self.train_config.dtype)
+
+ # model is loaded from BaseSDProcess
+ unet = self.sd.unet
+ vae = self.sd.vae
+ tokenizer = self.sd.tokenizer
+ text_encoder = self.sd.text_encoder
+ noise_scheduler = self.sd.noise_scheduler
+
+ if self.train_config.xformers:
+ vae.enable_xformers_memory_efficient_attention()
+ unet.enable_xformers_memory_efficient_attention()
+ if isinstance(text_encoder, list):
+ for te in text_encoder:
+ # if it has it
+ if hasattr(te, 'enable_xformers_memory_efficient_attention'):
+ te.enable_xformers_memory_efficient_attention()
+ if self.train_config.sdp:
+ torch.backends.cuda.enable_math_sdp(True)
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+ # # check if we have sage and is flux
+ # if self.sd.is_flux:
+ # # try_to_activate_sage_attn()
+ # try:
+ # from sageattention import sageattn
+ # from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0
+ # model: FluxTransformer2DModel = self.sd.unet
+ # # enable sage attention on each block
+ # for block in model.transformer_blocks:
+ # processor = FluxSageAttnProcessor2_0()
+ # block.attn.set_processor(processor)
+ # for block in model.single_transformer_blocks:
+ # processor = FluxSageAttnProcessor2_0()
+ # block.attn.set_processor(processor)
+
+ # except ImportError:
+ # print("sage attention is not installed. Using SDP instead")
+
+ if self.train_config.gradient_checkpointing:
+ if self.sd.is_flux:
+ unet.gradient_checkpointing = True
+ else:
+ unet.enable_gradient_checkpointing()
+ if isinstance(text_encoder, list):
+ for te in text_encoder:
+ if hasattr(te, 'enable_gradient_checkpointing'):
+ te.enable_gradient_checkpointing()
+ if hasattr(te, "gradient_checkpointing_enable"):
+ te.gradient_checkpointing_enable()
+ else:
+ if hasattr(text_encoder, 'enable_gradient_checkpointing'):
+ text_encoder.enable_gradient_checkpointing()
+ if hasattr(text_encoder, "gradient_checkpointing_enable"):
+ text_encoder.gradient_checkpointing_enable()
+
+ if self.sd.refiner_unet is not None:
+ self.sd.refiner_unet.to(self.device_torch, dtype=dtype)
+ self.sd.refiner_unet.requires_grad_(False)
+ self.sd.refiner_unet.eval()
+ if self.train_config.xformers:
+ self.sd.refiner_unet.enable_xformers_memory_efficient_attention()
+ if self.train_config.gradient_checkpointing:
+ self.sd.refiner_unet.enable_gradient_checkpointing()
+
+ if isinstance(text_encoder, list):
+ for te in text_encoder:
+ te.requires_grad_(False)
+ te.eval()
+ else:
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ unet.to(self.device_torch, dtype=dtype)
+ unet.requires_grad_(False)
+ unet.eval()
+ vae = vae.to(torch.device('cpu'), dtype=dtype)
+ vae.requires_grad_(False)
+ vae.eval()
+ if self.train_config.learnable_snr_gos:
+ self.snr_gos = LearnableSNRGamma(
+ self.sd.noise_scheduler, device=self.device_torch
+ )
+ # check to see if previous settings exist
+ path_to_load = os.path.join(self.save_root, 'learnable_snr.json')
+ if os.path.exists(path_to_load):
+ with open(path_to_load, 'r') as f:
+ json_data = json.load(f)
+ if 'offset' in json_data:
+ # legacy
+ self.snr_gos.offset_2.data = torch.tensor(json_data['offset'], device=self.device_torch)
+ else:
+ self.snr_gos.offset_1.data = torch.tensor(json_data['offset_1'], device=self.device_torch)
+ self.snr_gos.offset_2.data = torch.tensor(json_data['offset_2'], device=self.device_torch)
+ self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
+ self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
+
+ self.hook_after_model_load()
+ flush()
+ if not self.is_fine_tuning:
+ if self.network_config is not None:
+ # TODO should we completely switch to LycorisSpecialNetwork?
+ network_kwargs = self.network_config.network_kwargs
+ is_lycoris = False
+ is_lorm = self.network_config.type.lower() == 'lorm'
+                # use the LyCORIS network class for 'locon' / 'lycoris' types, otherwise the standard LoRA network
+ NetworkClass = LoRASpecialNetwork
+ if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
+ NetworkClass = LycorisSpecialNetwork
+ is_lycoris = True
+
+ if is_lorm:
+ network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains
+ network_kwargs['parameter_threshold'] = lorm_parameter_threshold
+ network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE
+
+ # if is_lycoris:
+ # preset = PRESET['full']
+ # NetworkClass.apply_preset(preset)
+
+ self.network = NetworkClass(
+ text_encoder=text_encoder,
+ unet=unet,
+ lora_dim=self.network_config.linear,
+ multiplier=1.0,
+ alpha=self.network_config.linear_alpha,
+ train_unet=self.train_config.train_unet,
+ train_text_encoder=self.train_config.train_text_encoder,
+ conv_lora_dim=self.network_config.conv,
+ conv_alpha=self.network_config.conv_alpha,
+ is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
+ is_v2=self.model_config.is_v2,
+ is_v3=self.model_config.is_v3,
+ is_pixart=self.model_config.is_pixart,
+ is_auraflow=self.model_config.is_auraflow,
+ is_flux=self.model_config.is_flux,
+ is_ssd=self.model_config.is_ssd,
+ is_vega=self.model_config.is_vega,
+ dropout=self.network_config.dropout,
+ use_text_encoder_1=self.model_config.use_text_encoder_1,
+ use_text_encoder_2=self.model_config.use_text_encoder_2,
+ use_bias=is_lorm,
+ is_lorm=is_lorm,
+ network_config=self.network_config,
+ network_type=self.network_config.type,
+ transformer_only=self.network_config.transformer_only,
+ **network_kwargs
+ )
+
+
+ # todo switch everything to proper mixed precision like this
+ self.network.force_to(self.device_torch, dtype=torch.float32)
+ # give network to sd so it can use it
+ self.sd.network = self.network
+ self.network._update_torch_multiplier()
+
+ self.network.apply_to(
+ text_encoder,
+ unet,
+ self.train_config.train_text_encoder,
+ self.train_config.train_unet
+ )
+
+ # we cannot merge in if quantized
+ if self.model_config.quantize:
+ # todo find a way around this
+ self.network.can_merge_in = False
+
+ if is_lorm:
+ self.network.is_lorm = True
+ # make sure it is on the right device
+ self.sd.unet.to(self.sd.device, dtype=dtype)
+ original_unet_param_count = count_parameters(self.sd.unet)
+ self.network.setup_lorm()
+ new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction()
+
+ print_lorm_extract_details(
+ start_num_params=original_unet_param_count,
+ end_num_params=new_unet_param_count,
+ num_replaced=len(self.network.get_all_modules()),
+ )
+
+ self.network.prepare_grad_etc(text_encoder, unet)
+ flush()
+
+                # LyCORIS doesn't have default_lr
+ config = {
+ 'text_encoder_lr': self.train_config.lr,
+ 'unet_lr': self.train_config.lr,
+ }
+ sig = inspect.signature(self.network.prepare_optimizer_params)
+ if 'default_lr' in sig.parameters:
+ config['default_lr'] = self.train_config.lr
+ if 'learning_rate' in sig.parameters:
+ config['learning_rate'] = self.train_config.lr
+ params_net = self.network.prepare_optimizer_params(
+ **config
+ )
+
+ params += params_net
+
+ if self.train_config.gradient_checkpointing:
+ self.network.enable_gradient_checkpointing()
+
+ lora_name = self.name
+ # need to adapt name so they are not mixed up
+ if self.named_lora:
+ lora_name = f"{lora_name}_LoRA"
+
+ latest_save_path = self.get_latest_save_path(lora_name)
+ extra_weights = None
+ if latest_save_path is not None:
+ self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
+ self.print(f"Loading from {latest_save_path}")
+ extra_weights = self.load_weights(latest_save_path)
+ self.network.multiplier = 1.0
+
+ if self.embed_config is not None:
+ # we are doing embedding training as well
+ self.embedding = Embedding(
+ sd=self.sd,
+ embed_config=self.embed_config
+ )
+ latest_save_path = self.get_latest_save_path(self.embed_config.trigger)
+ # load last saved weights
+ if latest_save_path is not None:
+ self.embedding.load_embedding_from_file(latest_save_path, self.device_torch)
+ if self.embedding.step > 1:
+ self.step_num = self.embedding.step
+ self.start_step = self.step_num
+
+ # self.step_num = self.embedding.step
+ # self.start_step = self.step_num
+ params.append({
+ 'params': list(self.embedding.get_trainable_params()),
+ 'lr': self.train_config.embedding_lr
+ })
+
+ flush()
+
+ if self.decorator_config is not None:
+ self.decorator = Decorator(
+ num_tokens=self.decorator_config.num_tokens,
+ token_size=4096 # t5xxl hidden size for flux
+ )
+ latest_save_path = self.get_latest_save_path()
+ # load last saved weights
+ if latest_save_path is not None:
+ state_dict = load_file(latest_save_path)
+ self.decorator.load_state_dict(state_dict)
+                    self.load_training_state_from_metadata(latest_save_path)
+
+ params.append({
+ 'params': list(self.decorator.parameters()),
+ 'lr': self.train_config.lr
+ })
+
+ # give it to the sd network
+ self.sd.decorator = self.decorator
+ self.decorator.to(self.device_torch, dtype=torch.float32)
+ self.decorator.train()
+
+ flush()
+
+ if self.adapter_config is not None:
+ self.setup_adapter()
+ if self.adapter_config.train:
+
+ if isinstance(self.adapter, IPAdapter):
+ # we have custom LR groups for IPAdapter
+ adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr)
+ for group in adapter_param_groups:
+ params.append(group)
+ else:
+ # set trainable params
+ params.append({
+ 'params': list(self.adapter.parameters()),
+ 'lr': self.train_config.adapter_lr
+ })
+
+ if self.train_config.gradient_checkpointing:
+ self.adapter.enable_gradient_checkpointing()
+ flush()
+
+ params = self.load_additional_training_modules(params)
+
+ else: # no network, embedding or adapter
+ # set the device state preset before getting params
+ self.sd.set_device_state(self.get_params_device_state_preset)
+
+ # params = self.get_params()
+ if len(params) == 0:
+ # will only return savable weights and ones with grad
+ params = self.sd.prepare_optimizer_params(
+ unet=self.train_config.train_unet,
+ text_encoder=self.train_config.train_text_encoder,
+ text_encoder_lr=self.train_config.lr,
+ unet_lr=self.train_config.lr,
+ default_lr=self.train_config.lr,
+ refiner=self.train_config.train_refiner and self.sd.refiner_unet is not None,
+ refiner_lr=self.train_config.refiner_lr,
+ )
+ # we may be using it for prompt injections
+ if self.adapter_config is not None and self.adapter is None:
+ self.setup_adapter()
+ flush()
+ ### HOOK ###
+ params = self.hook_add_extra_train_params(params)
+ self.params = params
+ # self.params = []
+
+ # for param in params:
+ # if isinstance(param, dict):
+ # self.params += param['params']
+ # else:
+ # self.params.append(param)
+
+ if self.train_config.start_step is not None:
+ self.step_num = self.train_config.start_step
+ self.start_step = self.step_num
+
+ optimizer_type = self.train_config.optimizer.lower()
+
+        # ensure params require grad
+ self.ensure_params_requires_grad(force=True)
+ optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
+ optimizer_params=self.train_config.optimizer_params)
+ self.optimizer = optimizer
+
+        # set it to do parameter swapping
+ if self.train_config.do_paramiter_swapping:
+ # only works for adafactor, but it should have thrown an error prior to this otherwise
+ self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor)
+
+ # check if it exists
+ optimizer_state_filename = f'optimizer.pt'
+ optimizer_state_file_path = os.path.join(self.save_root, optimizer_state_filename)
+ if os.path.exists(optimizer_state_file_path):
+ # try to load
+ # previous param groups
+ # previous_params = copy.deepcopy(optimizer.param_groups)
+ previous_lrs = []
+ for group in optimizer.param_groups:
+ previous_lrs.append(group['lr'])
+
+ try:
+ print(f"Loading optimizer state from {optimizer_state_file_path}")
+ optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
+ optimizer.load_state_dict(optimizer_state_dict)
+ del optimizer_state_dict
+ flush()
+ except Exception as e:
+ print(f"Failed to load optimizer state from {optimizer_state_file_path}")
+ print(e)
+
+ # update the optimizer LR from the params
+            print("Updating optimizer LR from params")
+ if len(previous_lrs) > 0:
+ for i, group in enumerate(optimizer.param_groups):
+ group['lr'] = previous_lrs[i]
+ group['initial_lr'] = previous_lrs[i]
+
+ # Update the learning rates if they changed
+ # optimizer.param_groups = previous_params
+
+ lr_scheduler_params = self.train_config.lr_scheduler_params
+
+        # make sure it has the bare minimum
+ if 'max_iterations' not in lr_scheduler_params:
+ lr_scheduler_params['total_iters'] = self.train_config.steps
+
+ lr_scheduler = get_lr_scheduler(
+ self.train_config.lr_scheduler,
+ optimizer,
+ **lr_scheduler_params
+ )
+ self.lr_scheduler = lr_scheduler
+
+        ### HOOK ###
+ self.before_dataset_load()
+ # load datasets if passed in the root process
+ if self.datasets is not None:
+ self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
+ if self.datasets_reg is not None:
+ self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
+ self.sd)
+
+ flush()
+ ### HOOK ###
+ self.hook_before_train_loop()
+
+ if self.has_first_sample_requested and self.step_num <= 1 and not self.train_config.disable_sampling:
+ self.print("Generating first sample from first sample config")
+ self.sample(0, is_first=True)
+
+ # sample first
+ if self.train_config.skip_first_sample or self.train_config.disable_sampling:
+ self.print("Skipping first sample due to config setting")
+ elif self.step_num <= 1 or self.train_config.force_first_sample:
+ self.print("Generating baseline samples before training")
+ self.sample(self.step_num)
+
+ self.progress_bar = ToolkitProgressBar(
+ total=self.train_config.steps,
+ desc=self.job.name,
+ leave=True,
+ initial=self.step_num,
+ iterable=range(0, self.train_config.steps),
+ )
+ self.progress_bar.pause()
+
+ if self.data_loader is not None:
+ dataloader = self.data_loader
+ dataloader_iterator = iter(dataloader)
+ else:
+ dataloader = None
+ dataloader_iterator = None
+
+ if self.data_loader_reg is not None:
+ dataloader_reg = self.data_loader_reg
+ dataloader_iterator_reg = iter(dataloader_reg)
+ else:
+ dataloader_reg = None
+ dataloader_iterator_reg = None
+
+ # zero any gradients
+ optimizer.zero_grad()
+
+ self.lr_scheduler.step(self.step_num)
+
+ self.sd.set_device_state(self.train_device_state_preset)
+ flush()
+ # self.step_num = 0
+
+ # print(f"Compiling Model")
+ # torch.compile(self.sd.unet, dynamic=True)
+
+ # make sure all params require grad
+ self.ensure_params_requires_grad(force=True)
+
+
+ ###################################################################
+ # TRAIN LOOP
+ ###################################################################
+
+
+ start_step_num = self.step_num
+ did_first_flush = False
+ for step in range(start_step_num, self.train_config.steps):
+ if self.train_config.do_paramiter_swapping:
+ self.optimizer.swap_paramiters()
+ self.timer.start('train_loop')
+ if self.train_config.do_random_cfg:
+ self.train_config.do_cfg = True
+ self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
+ self.step_num = step
+ # default to true so various things can turn it off
+ self.is_grad_accumulation_step = True
+ if self.train_config.free_u:
+ self.sd.pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
+ self.progress_bar.unpause()
+ with torch.no_grad():
+ # if is even step and we have a reg dataset, use that
+                # todo improve this logic to send one of each through if we can; buckets and batch size might be an issue
+ is_reg_step = False
+ is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0
+ is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0
+ if self.train_config.disable_sampling:
+ is_sample_step = False
+
+ batch_list = []
+
+ for b in range(self.train_config.gradient_accumulation):
+ # keep track to alternate on an accumulation step for reg
+ batch_step = step
+                    # don't do a reg step on sample or save steps as we don't want to normalize on those
+ if batch_step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step:
+ try:
+ with self.timer('get_batch:reg'):
+ batch = next(dataloader_iterator_reg)
+ except StopIteration:
+ with self.timer('reset_batch:reg'):
+ # hit the end of an epoch, reset
+ self.progress_bar.pause()
+ dataloader_iterator_reg = iter(dataloader_reg)
+ trigger_dataloader_setup_epoch(dataloader_reg)
+
+ with self.timer('get_batch:reg'):
+ batch = next(dataloader_iterator_reg)
+ self.progress_bar.unpause()
+ is_reg_step = True
+ elif dataloader is not None:
+ try:
+ with self.timer('get_batch'):
+ batch = next(dataloader_iterator)
+ except StopIteration:
+ with self.timer('reset_batch'):
+ # hit the end of an epoch, reset
+ self.progress_bar.pause()
+ dataloader_iterator = iter(dataloader)
+ trigger_dataloader_setup_epoch(dataloader)
+ self.epoch_num += 1
+ if self.train_config.gradient_accumulation_steps == -1:
+ # if we are accumulating for an entire epoch, trigger a step
+ self.is_grad_accumulation_step = False
+ self.grad_accumulation_step = 0
+ with self.timer('get_batch'):
+ batch = next(dataloader_iterator)
+ self.progress_bar.unpause()
+ else:
+ batch = None
+ batch_list.append(batch)
+ batch_step += 1
+
+ # setup accumulation
+ if self.train_config.gradient_accumulation_steps == -1:
+ # epoch is handling the accumulation, dont touch it
+ pass
+ else:
+ # determine if we are accumulating or not
+ # since optimizer step happens in the loop, we trigger it a step early
+ # since we cannot reprocess it before them
+ optimizer_step_at = self.train_config.gradient_accumulation_steps
+ is_optimizer_step = self.grad_accumulation_step >= optimizer_step_at
+ self.is_grad_accumulation_step = not is_optimizer_step
+ if is_optimizer_step:
+ self.grad_accumulation_step = 0
+
+ # flush()
+ ### HOOK ###
+
+ loss_dict = self.hook_train_loop(batch_list)
+ self.timer.stop('train_loop')
+ if not did_first_flush:
+ flush()
+ did_first_flush = True
+ # flush()
+            # set up the networks so gradient checkpointing and everything works
+ if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
+ self.adapter.clear_memory()
+
+ with torch.no_grad():
+ # torch.cuda.empty_cache()
+ # if optimizer has get_lrs method, then use it
+ if hasattr(optimizer, 'get_avg_learning_rate'):
+ learning_rate = optimizer.get_avg_learning_rate()
+ elif hasattr(optimizer, 'get_learning_rates'):
+ learning_rate = optimizer.get_learning_rates()[0]
+ elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+ self.train_config.optimizer.lower().startswith('prodigy'):
+ learning_rate = (
+ optimizer.param_groups[0]["d"] *
+ optimizer.param_groups[0]["lr"]
+ )
+ else:
+ learning_rate = optimizer.param_groups[0]['lr']
+
+ prog_bar_string = f"lr: {learning_rate:.1e}"
+ for key, value in loss_dict.items():
+ prog_bar_string += f" {key}: {value:.3e}"
+
+ self.progress_bar.set_postfix_str(prog_bar_string)
+
+ # if the batch is a DataLoaderBatchDTO, then we need to clean it up
+ if isinstance(batch, DataLoaderBatchDTO):
+ with self.timer('batch_cleanup'):
+ batch.cleanup()
+
+ # don't do on first step
+ if self.step_num != self.start_step:
+ if is_sample_step:
+ self.progress_bar.pause()
+ flush()
+ # print above the progress bar
+ if self.train_config.free_u:
+ self.sd.pipeline.disable_freeu()
+ self.sample(self.step_num)
+ if self.train_config.unload_text_encoder:
+ # make sure the text encoder is unloaded
+ self.sd.text_encoder_to('cpu')
+ flush()
+
+ self.ensure_params_requires_grad()
+ self.progress_bar.unpause()
+
+ if is_save_step:
+ # print above the progress bar
+ self.progress_bar.pause()
+ self.print(f"Saving at step {self.step_num}")
+ self.save(self.step_num)
+ self.ensure_params_requires_grad()
+ self.progress_bar.unpause()
+
+ if self.logging_config.log_every and self.step_num % self.logging_config.log_every == 0:
+ self.progress_bar.pause()
+ with self.timer('log_to_tensorboard'):
+ # log to tensorboard
+ if self.writer is not None:
+ for key, value in loss_dict.items():
+ self.writer.add_scalar(f"{key}", value, self.step_num)
+ self.writer.add_scalar(f"lr", learning_rate, self.step_num)
+ self.progress_bar.unpause()
+
+ # log to logger
+ self.logger.log({
+ 'learning_rate': learning_rate,
+ })
+ for key, value in loss_dict.items():
+ self.logger.log({
+ f'loss/{key}': value,
+ })
+ elif self.logging_config.log_every is None:
+ # log every step
+ self.logger.log({
+ 'learning_rate': learning_rate,
+ })
+ for key, value in loss_dict.items():
+ self.logger.log({
+ f'loss/{key}': value,
+ })
+
+
+ if self.performance_log_every > 0 and self.step_num % self.performance_log_every == 0:
+ self.progress_bar.pause()
+ # print the timers and clear them
+ self.timer.print()
+ self.timer.reset()
+ self.progress_bar.unpause()
+
+ # commit log
+ self.logger.commit(step=self.step_num)
+
+                # set the progress bar to match our step
+ self.progress_bar.update(step - self.progress_bar.n)
+
+ #############################
+ # End of step
+ #############################
+
+ # update various steps
+ self.step_num = step + 1
+ self.grad_accumulation_step += 1
+
+
+ ###################################################################
+ ## END TRAIN LOOP
+ ###################################################################
+
+ self.progress_bar.close()
+ if self.train_config.free_u:
+ self.sd.pipeline.disable_freeu()
+ if not self.train_config.disable_sampling:
+ self.sample(self.step_num)
+ self.logger.commit(step=self.step_num)
+ print("")
+ self.save()
+ self.logger.finish()
+
+ if self.save_config.push_to_hub:
+        if "HF_TOKEN" not in os.environ:
+ interpreter_login(new_session=False, write_permission=True)
+ self.push_to_hub(
+ repo_id=self.save_config.hf_repo_id,
+ private=self.save_config.hf_private
+ )
+ del (
+ self.sd,
+ unet,
+ noise_scheduler,
+ optimizer,
+ self.network,
+ tokenizer,
+ text_encoder,
+ )
+
+ flush()
+
+ def push_to_hub(
+ self,
+ repo_id: str,
+ private: bool = False,
+ ):
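+        # Write a README model card into the save folder and upload everything to the Hub,
+        # skipping yaml and pt files.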
+ readme_content = self._generate_readme(repo_id)
+ readme_path = os.path.join(self.save_root, "README.md")
+ with open(readme_path, "w", encoding="utf-8") as f:
+ f.write(readme_content)
+
+ api = HfApi()
+
+ api.create_repo(
+ repo_id,
+ private=private,
+ exist_ok=True
+ )
+
+ api.upload_folder(
+ repo_id=repo_id,
+ folder_path=self.save_root,
+ ignore_patterns=["*.yaml", "*.pt"],
+ repo_type="model",
+ )
+
+
+ def _generate_readme(self, repo_id: str) -> str:
+ """Generates the content of the README.md file."""
+
+ # Gather model info
+ base_model = self.model_config.name_or_path
+ instance_prompt = self.trigger_word if hasattr(self, "trigger_word") else None
+ if base_model == "black-forest-labs/FLUX.1-schnell":
+ license = "apache-2.0"
+ elif base_model == "black-forest-labs/FLUX.1-dev":
+ license = "other"
+ license_name = "flux-1-dev-non-commercial-license"
+ license_link = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
+ else:
+ license = "creativeml-openrail-m"
+ tags = [
+ "text-to-image",
+ ]
+ if self.model_config.is_xl:
+ tags.append("stable-diffusion-xl")
+ if self.model_config.is_flux:
+ tags.append("flux")
+ if self.model_config.is_v3:
+ tags.append("sd3")
+ if self.network_config:
+ tags.extend(
+ [
+ "lora",
+ "diffusers",
+ "template:sd-lora",
+ "ai-toolkit",
+ ]
+ )
+
+ # Generate the widget section
+ widgets = []
+ sample_image_paths = []
+ samples_dir = os.path.join(self.save_root, "samples")
+ if os.path.isdir(samples_dir):
+ for filename in os.listdir(samples_dir):
+                # The filenames are structured as 1724085406830__00000500_0.jpg,
+                # so we capture the 2nd part (steps) and the 3rd (index that matches the prompt)
+ match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
+ if match:
+ steps, index = int(match.group(1)), int(match.group(2))
+                    # Only upload the latest samples, i.e. those matching the final number of steps
+ if steps == self.train_config.steps:
+ sample_image_paths.append((index, f"samples/{filename}"))
+
+ # Sort by numeric index
+ sample_image_paths.sort(key=lambda x: x[0])
+
+ # Create widgets matching prompt with the index
+ for i, prompt in enumerate(self.sample_config.prompts):
+ if i < len(sample_image_paths):
+ # Associate prompts with sample image paths based on the extracted index
+ _, image_path = sample_image_paths[i]
+ widgets.append(
+ {
+ "text": prompt,
+ "output": {
+ "url": image_path
+ },
+ }
+ )
+ dtype = "torch.bfloat16" if self.model_config.is_flux else "torch.float16"
+ # Construct the README content
+ readme_content = f"""---
+tags:
+{yaml.dump(tags, indent=4).strip()}
+{"widget:" if os.path.isdir(samples_dir) else ""}
+{yaml.dump(widgets, indent=4).strip() if widgets else ""}
+base_model: {base_model}
+{"instance_prompt: " + instance_prompt if instance_prompt else ""}
+license: {license}
+{'license_name: ' + license_name if license == "other" else ""}
+{'license_link: ' + license_link if license == "other" else ""}
+---
+
+# {self.job.name}
+Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
+
+
+## Trigger words
+
+{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
+
+## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.
+
+Weights for this model are available in Safetensors format.
+
+[Download](/{repo_id}/tree/main) them in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='{self.job.name}.safetensors')
+image = pipeline('{instance_prompt if not widgets else self.sample_config.prompts[0]}').images[0]
+image.save("my_image.png")
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+"""
+ return readme_content
diff --git a/jobs/process/BaseTrainProcess.py b/jobs/process/BaseTrainProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1885de23930ab372f26efa2b2281c9a9332f4fe
--- /dev/null
+++ b/jobs/process/BaseTrainProcess.py
@@ -0,0 +1,79 @@
+import random
+from datetime import datetime
+import os
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Union
+
+import torch
+import yaml
+
+from jobs.process.BaseProcess import BaseProcess
+
+if TYPE_CHECKING:
+ from jobs import TrainJob, BaseJob, ExtensionJob
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+
+
+class BaseTrainProcess(BaseProcess):
+
+ def __init__(
+ self,
+ process_id: int,
+ job,
+ config: OrderedDict
+ ):
+ super().__init__(process_id, job, config)
+ self.process_id: int
+ self.config: OrderedDict
+ self.writer: 'SummaryWriter'
+ self.job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
+ self.progress_bar: 'tqdm' = None
+
+ self.training_seed = self.get_conf('training_seed', self.job.training_seed if hasattr(self.job, 'training_seed') else None)
+ # if training seed is set, use it
+ if self.training_seed is not None:
+ torch.manual_seed(self.training_seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(self.training_seed)
+ random.seed(self.training_seed)
+
+ self.progress_bar = None
+ self.writer = None
+ self.training_folder = self.get_conf('training_folder',
+ self.job.training_folder if hasattr(self.job, 'training_folder') else None)
+ self.save_root = os.path.join(self.training_folder, self.name)
+ self.step = 0
+ self.first_step = 0
+ self.log_dir = self.get_conf('log_dir', self.job.log_dir if hasattr(self.job, 'log_dir') else None)
+ self.setup_tensorboard()
+ self.save_training_config()
+
+ def run(self):
+ super().run()
+ # implement in child class
+ # be sure to call super().run() first
+ pass
+
+ # def print(self, message, **kwargs):
+ def print(self, *args):
+ if self.progress_bar is not None:
+ self.progress_bar.write(' '.join(map(str, args)))
+ self.progress_bar.update()
+ else:
+ print(*args)
+
+ def setup_tensorboard(self):
+ if self.log_dir:
+ from torch.utils.tensorboard import SummaryWriter
+ now = datetime.now()
+ time_str = now.strftime('%Y%m%d-%H%M%S')
+ summary_name = f"{self.name}_{time_str}"
+ summary_dir = os.path.join(self.log_dir, summary_name)
+ self.writer = SummaryWriter(summary_dir)
+
+ def save_training_config(self):
+ os.makedirs(self.save_root, exist_ok=True)
+        save_path = os.path.join(self.save_root, 'config.yaml')
+        with open(save_path, 'w') as f:
+ yaml.dump(self.job.raw_config, f)
diff --git a/jobs/process/ExtractLoconProcess.py b/jobs/process/ExtractLoconProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5dac5edd7bcc5fb959fb4a3717bfa975d1264cc
--- /dev/null
+++ b/jobs/process/ExtractLoconProcess.py
@@ -0,0 +1,68 @@
+from collections import OrderedDict
+from toolkit.lycoris_utils import extract_diff
+from .BaseExtractProcess import BaseExtractProcess
+
+mode_dict = {
+ 'fixed': {
+ 'linear': 64,
+ 'conv': 32,
+ 'type': int
+ },
+ 'threshold': {
+ 'linear': 0,
+ 'conv': 0,
+ 'type': float
+ },
+ 'ratio': {
+ 'linear': 0.5,
+ 'conv': 0.5,
+ 'type': float
+ },
+ 'quantile': {
+ 'linear': 0.5,
+ 'conv': 0.5,
+ 'type': float
+ }
+}
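+# Note: these defaults mirror the LyCORIS-style extraction modes; the exact semantics come from
+# toolkit.lycoris_utils.extract_diff (presumably 'fixed' uses a fixed rank, while 'threshold',
+# 'ratio' and 'quantile' derive the rank from the singular values of the weight difference).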
+
+
+class ExtractLoconProcess(BaseExtractProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+ self.mode = self.get_conf('mode', 'fixed')
+ self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
+ self.sparsity = self.get_conf('sparsity', 0.98)
+ self.disable_cp = self.get_conf('disable_cp', False)
+
+ # set modes
+ if self.mode not in list(mode_dict.keys()):
+ raise ValueError(f"Unknown mode: {self.mode}")
+ self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
+ self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])
+
+ def run(self):
+ super().run()
+ print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
+
+ state_dict, extract_diff_meta = extract_diff(
+ self.job.model_base,
+ self.job.model_extract,
+ self.mode,
+ self.linear_param,
+ self.conv_param,
+ self.job.device,
+ self.use_sparse_bias,
+ self.sparsity,
+ not self.disable_cp,
+ extract_unet=self.extract_unet,
+ extract_text_encoder=self.extract_text_encoder
+ )
+
+ self.add_meta(extract_diff_meta)
+ self.save(state_dict)
+
+ def get_output_path(self, prefix=None, suffix=None):
+ if suffix is None:
+ suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
+ return super().get_output_path(prefix, suffix)
+
diff --git a/jobs/process/ExtractLoraProcess.py b/jobs/process/ExtractLoraProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f0cc942b0c6d76139223851965e643dfb31376
--- /dev/null
+++ b/jobs/process/ExtractLoraProcess.py
@@ -0,0 +1,73 @@
+from collections import OrderedDict
+from toolkit.lycoris_utils import extract_diff
+from .BaseExtractProcess import BaseExtractProcess
+
+
+mode_dict = {
+ 'fixed': {
+ 'linear': 4,
+ 'conv': 0,
+ 'type': int
+ },
+ 'threshold': {
+ 'linear': 0,
+ 'conv': 0,
+ 'type': float
+ },
+ 'ratio': {
+ 'linear': 0.5,
+ 'conv': 0,
+ 'type': float
+ },
+ 'quantile': {
+ 'linear': 0.5,
+ 'conv': 0,
+ 'type': float
+ }
+}
+
+CLAMP_QUANTILE = 0.99
+MIN_DIFF = 1e-6
+
+
+class ExtractLoraProcess(BaseExtractProcess):
+
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+ self.mode = self.get_conf('mode', 'fixed')
+
+ # set modes
+ if self.mode not in list(mode_dict.keys()):
+ raise ValueError(f"Unknown mode: {self.mode}")
+        self.linear_param = self.get_conf('linear', mode_dict[self.mode]['linear'], as_type=mode_dict[self.mode]['type'])
+ self.conv_param = self.get_conf('conv', mode_dict[self.mode]['conv'], as_type=mode_dict[self.mode]['type'])
+ self.use_sparse_bias = self.get_conf('use_sparse_bias', False)
+ self.sparsity = self.get_conf('sparsity', 0.98)
+
+ def run(self):
+ super().run()
+        print(f"Running process: {self.mode}, lin: {self.linear_param}, conv: {self.conv_param}")
+
+ state_dict, extract_diff_meta = extract_diff(
+ self.job.model_base,
+ self.job.model_extract,
+ self.mode,
+ self.linear_param,
+ self.conv_param,
+ self.job.device,
+ self.use_sparse_bias,
+ self.sparsity,
+ small_conv=False,
+ linear_only=self.conv_param > 0.0000000001,
+ extract_unet=self.extract_unet,
+ extract_text_encoder=self.extract_text_encoder
+ )
+
+ self.add_meta(extract_diff_meta)
+ self.save(state_dict)
+
+ def get_output_path(self, prefix=None, suffix=None):
+ if suffix is None:
+            suffix = f"_{self.linear_param}"
+ return super().get_output_path(prefix, suffix)
diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0cb32d8e0d1cdc6bdc1f723bf73495fae14c809
--- /dev/null
+++ b/jobs/process/GenerateProcess.py
@@ -0,0 +1,146 @@
+import gc
+import os
+from collections import OrderedDict
+from typing import ForwardRef, List, Optional, Union
+
+import torch
+from safetensors.torch import save_file, load_file
+
+from jobs.process.BaseProcess import BaseProcess
+from toolkit.config_modules import ModelConfig, GenerateImageConfig
+from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
+ add_base_model_info_to_meta
+from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.train_tools import get_torch_dtype
+import random
+
+
+class GenerateConfig:
+
+ def __init__(self, **kwargs):
+ self.prompts: List[str]
+ self.sampler = kwargs.get('sampler', 'ddpm')
+ self.width = kwargs.get('width', 512)
+ self.height = kwargs.get('height', 512)
+ self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
+ self.neg = kwargs.get('neg', '')
+ self.seed = kwargs.get('seed', -1)
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
+ self.sample_steps = kwargs.get('sample_steps', 20)
+ self.prompt_2 = kwargs.get('prompt_2', None)
+ self.neg_2 = kwargs.get('neg_2', None)
+ self.prompts = kwargs.get('prompts', None)
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+ self.compile = kwargs.get('compile', False)
+ self.ext = kwargs.get('ext', 'png')
+ self.prompt_file = kwargs.get('prompt_file', False)
+ self.num_repeats = kwargs.get('num_repeats', 1)
+ self.prompts_in_file = self.prompts
+ if self.prompts is None:
+ raise ValueError("Prompts must be set")
+ if isinstance(self.prompts, str):
+ if os.path.exists(self.prompts):
+ with open(self.prompts, 'r', encoding='utf-8') as f:
+ self.prompts_in_file = f.read().splitlines()
+ self.prompts_in_file = [p.strip() for p in self.prompts_in_file if len(p.strip()) > 0]
+ else:
+ raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
+
+ self.random_prompts = kwargs.get('random_prompts', False)
+ self.max_random_per_prompt = kwargs.get('max_random_per_prompt', 1)
+ self.max_images = kwargs.get('max_images', 10000)
+
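+        # when random_prompts is enabled, each image prompt is built by joining
+        # 1 to max_random_per_prompt randomly chosen lines from the prompt file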
+ if self.random_prompts:
+ self.prompts = []
+ for i in range(self.max_images):
+ num_prompts = random.randint(1, self.max_random_per_prompt)
+ prompt_list = [random.choice(self.prompts_in_file) for _ in range(num_prompts)]
+ self.prompts.append(", ".join(prompt_list))
+ else:
+ self.prompts = self.prompts_in_file
+
+ if kwargs.get('shuffle', False):
+ # shuffle the prompts
+ random.shuffle(self.prompts)
+
+
+class GenerateProcess(BaseProcess):
+ process_id: int
+ config: OrderedDict
+ progress_bar: ForwardRef('tqdm') = None
+ sd: StableDiffusion
+
+ def __init__(
+ self,
+ process_id: int,
+ job,
+ config: OrderedDict
+ ):
+ super().__init__(process_id, job, config)
+ self.output_folder = self.get_conf('output_folder', required=True)
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
+ self.device = self.get_conf('device', self.job.device)
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
+ self.torch_dtype = get_torch_dtype(self.get_conf('dtype', 'float16'))
+
+ self.progress_bar = None
+ self.sd = StableDiffusion(
+ device=self.device,
+ model_config=self.model_config,
+ dtype=self.model_config.dtype,
+ )
+
+ print(f"Using device {self.device}")
+
+ def clean_prompt(self, prompt: str):
+ # remove any non alpha numeric characters or ,'" from prompt
+ return ''.join(e for e in prompt if e.isalnum() or e in ", '\"")
+
+ def run(self):
+ with torch.no_grad():
+ super().run()
+ print("Loading model...")
+ self.sd.load_model()
+ self.sd.pipeline.to(self.device, self.torch_dtype)
+
+ print("Compiling model...")
+ # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+ if self.generate_config.compile:
+ self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
+
+ print(f"Generating {len(self.generate_config.prompts)} images")
+ # build prompt image configs
+ prompt_image_configs = []
+ for _ in range(self.generate_config.num_repeats):
+ for prompt in self.generate_config.prompts:
+ width = self.generate_config.width
+ height = self.generate_config.height
+ # prompt = self.clean_prompt(prompt)
+
+ if self.generate_config.size_list is not None:
+ # randomly select a size
+ width, height = random.choice(self.generate_config.size_list)
+
+ prompt_image_configs.append(GenerateImageConfig(
+ prompt=prompt,
+ prompt_2=self.generate_config.prompt_2,
+ width=width,
+ height=height,
+ num_inference_steps=self.generate_config.sample_steps,
+ guidance_scale=self.generate_config.guidance_scale,
+ negative_prompt=self.generate_config.neg,
+ negative_prompt_2=self.generate_config.neg_2,
+ seed=self.generate_config.seed,
+ guidance_rescale=self.generate_config.guidance_rescale,
+ output_ext=self.generate_config.ext,
+ output_folder=self.output_folder,
+ add_prompt_file=self.generate_config.prompt_file
+ ))
+ # generate images
+ self.sd.generate_images(prompt_image_configs, sampler=self.generate_config.sampler)
+
+ print("Done generating images")
+ # cleanup
+ del self.sd
+ gc.collect()
+ torch.cuda.empty_cache()
diff --git a/jobs/process/MergeLoconProcess.py b/jobs/process/MergeLoconProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..00c70cd2abdbc894f7b00c6cbf51a3dcfcc95531
--- /dev/null
+++ b/jobs/process/MergeLoconProcess.py
@@ -0,0 +1,20 @@
+from collections import OrderedDict
+from toolkit.lycoris_utils import extract_diff
+from .BaseExtractProcess import BaseExtractProcess
+
+
+class MergeLoconProcess(BaseExtractProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+
+ def run(self):
+ super().run()
+ new_state_dict = {}
+ raise NotImplementedError("This is not implemented yet")
+
+
+ def get_output_path(self, prefix=None, suffix=None):
+ if suffix is None:
+ suffix = f"_{self.mode}_{self.linear_param}_{self.conv_param}"
+ return super().get_output_path(prefix, suffix)
+
diff --git a/jobs/process/ModRescaleLoraProcess.py b/jobs/process/ModRescaleLoraProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bb7436098f95ed774c7f1febc1b8bd7c0791981
--- /dev/null
+++ b/jobs/process/ModRescaleLoraProcess.py
@@ -0,0 +1,104 @@
+import gc
+import os
+from collections import OrderedDict
+from typing import ForwardRef
+
+import torch
+from safetensors.torch import save_file, load_file
+
+from jobs.process.BaseProcess import BaseProcess
+from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_model_hash_to_meta, \
+ add_base_model_info_to_meta
+from toolkit.train_tools import get_torch_dtype
+
+
+class ModRescaleLoraProcess(BaseProcess):
+ process_id: int
+ config: OrderedDict
+ progress_bar: ForwardRef('tqdm') = None
+
+ def __init__(
+ self,
+ process_id: int,
+ job,
+ config: OrderedDict
+ ):
+ super().__init__(process_id, job, config)
+ self.process_id: int
+ self.config: OrderedDict
+ self.progress_bar: ForwardRef('tqdm') = None
+ self.input_path = self.get_conf('input_path', required=True)
+ self.output_path = self.get_conf('output_path', required=True)
+ self.replace_meta = self.get_conf('replace_meta', default=False)
+ self.save_dtype = self.get_conf('save_dtype', default='fp16', as_type=get_torch_dtype)
+ self.current_weight = self.get_conf('current_weight', required=True, as_type=float)
+ self.target_weight = self.get_conf('target_weight', required=True, as_type=float)
+ self.scale_target = self.get_conf('scale_target', default='up_down') # alpha or up_down
+ self.is_xl = self.get_conf('is_xl', default=False, as_type=bool)
+ self.is_v2 = self.get_conf('is_v2', default=False, as_type=bool)
+
+ self.progress_bar = None
+
+ def run(self):
+ super().run()
+ source_state_dict = load_file(self.input_path)
+ source_meta = load_metadata_from_safetensors(self.input_path)
+
+ if self.replace_meta:
+ self.meta.update(
+ add_base_model_info_to_meta(
+ self.meta,
+ is_xl=self.is_xl,
+ is_v2=self.is_v2,
+ )
+ )
+ save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+ else:
+ save_meta = get_meta_for_safetensors(source_meta, self.job.name, add_software_info=False)
+
+ # save
+ os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
+
+ new_state_dict = OrderedDict()
+
+ for key in list(source_state_dict.keys()):
+ v = source_state_dict[key]
+ v = v.detach().clone().to("cpu").to(get_torch_dtype('fp32'))
+
+ # all loras have an alpha, up weight and down weight
+ # - "lora_te_text_model_encoder_layers_0_mlp_fc1.alpha",
+ # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
+ # - "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight",
+ # we can rescale by adjusting the alpha or the up weights, or the up and down weights
+ # I assume doing both up and down would be best all around, but I'm not sure
+ # some locons also have mid weights, we will leave those alone for now, will work without them
+
+ # when adjusting alpha, it is used to calculate the multiplier in a lora module
+ # - scale = alpha / lora_dim
+ # - output = layer_out + lora_up_out * multiplier * scale
+ total_module_scale = torch.tensor(self.current_weight / self.target_weight) \
+ .to("cpu", dtype=get_torch_dtype('fp32'))
+ num_modules_layers = 2 # up and down
+ up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
+ .to("cpu", dtype=get_torch_dtype('fp32'))
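+            # illustrative example (assumed values): current_weight=1.0 and target_weight=0.5 give
+            # total_module_scale=2.0, so each of the up and down matrices is scaled by sqrt(2) ~= 1.414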
+ # only update alpha
+ if self.scale_target == 'alpha' and key.endswith('.alpha'):
+ v = v * total_module_scale
+            if self.scale_target == 'up_down' and (key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight')):
+ # would it be better to adjust the up weights for fp16 precision? Doing both should reduce chance of NaN
+ v = v * up_down_scale
+ v = v.detach().clone().to("cpu").to(self.save_dtype)
+ new_state_dict[key] = v
+
+ save_meta = add_model_hash_to_meta(new_state_dict, save_meta)
+ save_file(new_state_dict, self.output_path, save_meta)
+
+        # clean up in case there are other jobs
+ del new_state_dict
+ del source_state_dict
+ del source_meta
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ print(f"Saved to {self.output_path}")
diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff3a69d89260396232f6085d161db3afe26b668
--- /dev/null
+++ b/jobs/process/TrainESRGANProcess.py
@@ -0,0 +1,657 @@
+import copy
+import glob
+import os
+import time
+from collections import OrderedDict
+from typing import List, Optional
+
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+
+from toolkit.basic import flush
+from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
+from safetensors.torch import save_file, load_file
+from torch.utils.data import DataLoader, ConcatDataset
+import torch
+from torch import nn
+from torchvision.transforms import transforms
+
+from jobs.process import BaseTrainProcess
+from toolkit.data_loader import AugmentedImageDataset
+from toolkit.esrgan_utils import convert_state_dict_to_basicsr, convert_basicsr_state_dict_to_save_format
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.optimizer import get_optimizer
+from toolkit.style import get_style_model_and_losses
+from toolkit.train_tools import get_torch_dtype
+from diffusers import AutoencoderKL
+from tqdm import tqdm
+import time
+import numpy as np
+from .models.vgg19_critic import Critic
+
+IMAGE_TRANSFORMS = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ # transforms.Normalize([0.5], [0.5]),
+ ]
+)
+
+
+class TrainESRGANProcess(BaseTrainProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+ self.data_loader = None
+ self.model: ESRGAN = None
+ self.device = self.get_conf('device', self.job.device)
+ self.pretrained_path = self.get_conf('pretrained_path', 'None')
+ self.datasets_objects = self.get_conf('datasets', required=True)
+ self.batch_size = self.get_conf('batch_size', 1, as_type=int)
+ self.resolution = self.get_conf('resolution', 256, as_type=int)
+ self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
+ self.sample_every = self.get_conf('sample_every', None)
+ self.optimizer_type = self.get_conf('optimizer', 'adam')
+ self.epochs = self.get_conf('epochs', None, as_type=int)
+ self.max_steps = self.get_conf('max_steps', None, as_type=int)
+ self.save_every = self.get_conf('save_every', None)
+ self.upscale_sample = self.get_conf('upscale_sample', 4)
+ self.dtype = self.get_conf('dtype', 'float32')
+ self.sample_sources = self.get_conf('sample_sources', None)
+ self.log_every = self.get_conf('log_every', 100, as_type=int)
+ self.style_weight = self.get_conf('style_weight', 0, as_type=float)
+ self.content_weight = self.get_conf('content_weight', 0, as_type=float)
+ self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+ self.zoom = self.get_conf('zoom', 4, as_type=int)
+ self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+ self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
+ self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+ self.optimizer_params = self.get_conf('optimizer_params', {})
+ self.augmentations = self.get_conf('augmentations', {})
+ self.torch_dtype = get_torch_dtype(self.dtype)
+        # keep the ESRGAN model itself in float32 regardless of the training dtype
+        self.esrgan_dtype = torch.float32
+
+ self.vgg_19 = None
+ self.style_weight_scalers = []
+ self.content_weight_scalers = []
+
+        # throw an error if zoom is not divisible by 2
+ if self.zoom % 2 != 0:
+ raise ValueError('zoom must be divisible by 2')
+
+ self.step_num = 0
+ self.epoch_num = 0
+
+ self.use_critic = self.get_conf('use_critic', False, as_type=bool)
+ self.critic = None
+
+ if self.use_critic:
+ self.critic = Critic(
+ device=self.device,
+ dtype=self.dtype,
+ process=self,
+ **self.get_conf('critic', {}) # pass any other params
+ )
+
+ if self.sample_every is not None and self.sample_sources is None:
+ raise ValueError('sample_every is specified but sample_sources is not')
+
+ if self.epochs is None and self.max_steps is None:
+ raise ValueError('epochs or max_steps must be specified')
+
+ self.data_loaders = []
+ # check datasets
+ assert isinstance(self.datasets_objects, list)
+ for dataset in self.datasets_objects:
+ if 'path' not in dataset:
+ raise ValueError('dataset must have a path')
+ # check if is dir
+ if not os.path.isdir(dataset['path']):
+                raise ValueError(f"dataset path is not a directory: {dataset['path']}")
+
+ # make training folder
+ if not os.path.exists(self.save_root):
+ os.makedirs(self.save_root, exist_ok=True)
+
+ self._pattern_loss = None
+
+ # build augmentation transforms
+ aug_transforms = []
+
+ def update_training_metadata(self):
+ self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
+
+ def get_training_info(self):
+ info = OrderedDict({
+ 'step': self.step_num,
+ 'epoch': self.epoch_num,
+ })
+ return info
+
+ def load_datasets(self):
+ if self.data_loader is None:
+ print(f"Loading datasets")
+ datasets = []
+ for dataset in self.datasets_objects:
+ print(f" - Dataset: {dataset['path']}")
+ ds = copy.copy(dataset)
+ ds['resolution'] = self.resolution
+
+ if 'augmentations' not in ds:
+ ds['augmentations'] = self.augmentations
+
+ # add the resize down augmentation
+ ds['augmentations'] = [{
+ 'method': 'Resize',
+ 'params': {
+ 'width': int(self.resolution // self.zoom),
+ 'height': int(self.resolution // self.zoom),
+ # downscale interpolation, string will be evaluated
+ 'interpolation': 'cv2.INTER_AREA'
+ }
+ }] + ds['augmentations']
+
+ image_dataset = AugmentedImageDataset(ds)
+ datasets.append(image_dataset)
+
+ concatenated_dataset = ConcatDataset(datasets)
+ self.data_loader = DataLoader(
+ concatenated_dataset,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=6
+ )
+
+ def setup_vgg19(self):
+ if self.vgg_19 is None:
+ self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
+ single_target=True,
+ device=self.device,
+ output_layer_name='pool_4',
+ dtype=self.torch_dtype
+ )
+ self.vgg_19.to(self.device, dtype=self.torch_dtype)
+ self.vgg_19.requires_grad_(False)
+
+ # we run random noise through first to get layer scalers to normalize the loss per layer
+ # bs of 2 because we run pred and target through stacked
+ noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
+ self.vgg_19(noise)
+ for style_loss in self.style_losses:
+ # get a scaler to normalize to 1
+ scaler = 1 / torch.mean(style_loss.loss).item()
+ self.style_weight_scalers.append(scaler)
+ for content_loss in self.content_losses:
+ # get a scaler to normalize to 1
+ scaler = 1 / torch.mean(content_loss.loss).item()
+ # if is nan, set to 1
+ if scaler != scaler:
+ scaler = 1
+ print(f"Warning: content loss scaler is nan, setting to 1")
+ self.content_weight_scalers.append(scaler)
+
+ self.print(f"Style weight scalers: {self.style_weight_scalers}")
+ self.print(f"Content weight scalers: {self.content_weight_scalers}")
+
+ def get_style_loss(self):
+ if self.style_weight > 0:
+ # scale all losses with loss scalers
+ loss = torch.sum(
+ torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_content_loss(self):
+ if self.content_weight > 0:
+ # scale all losses with loss scalers
+ loss = torch.sum(torch.stack(
+ [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_mse_loss(self, pred, target):
+ if self.mse_weight > 0:
+ loss_fn = nn.MSELoss()
+ loss = loss_fn(pred, target)
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_tv_loss(self, pred, target):
+ if self.tv_weight > 0:
+ get_tv_loss = ComparativeTotalVariation()
+ loss = get_tv_loss(pred, target)
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_pattern_loss(self, pred, target):
+ if self._pattern_loss is None:
+ self._pattern_loss = PatternLoss(
+ pattern_size=self.zoom,
+ dtype=self.torch_dtype
+ ).to(self.device, dtype=self.torch_dtype)
+ self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
+ loss = torch.mean(self._pattern_loss(pred, target))
+ return loss
+
+ def save(self, step=None):
+ if not os.path.exists(self.save_root):
+ os.makedirs(self.save_root, exist_ok=True)
+
+ step_num = ''
+ if step is not None:
+ # zeropad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+
+ self.update_training_metadata()
+ # filename = f'{self.job.name}{step_num}.safetensors'
+ filename = f'{self.job.name}{step_num}.pth'
+ # prepare meta
+ save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+
+ # state_dict = self.model.state_dict()
+
+ # state has the original state dict keys so we can save what we started from
+ save_state_dict = self.model.state_dict()
+
+ for key in list(save_state_dict.keys()):
+ v = save_state_dict[key]
+ v = v.detach().clone().to("cpu").to(torch.float32)
+ save_state_dict[key] = v
+
+ # most things wont use safetensors, save as torch
+ # save_file(save_state_dict, os.path.join(self.save_root, filename), save_meta)
+ torch.save(save_state_dict, os.path.join(self.save_root, filename))
+
+ self.print(f"Saved to {os.path.join(self.save_root, filename)}")
+
+ if self.use_critic:
+ self.critic.save(step)
+
+ def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
+ sample_folder = os.path.join(self.save_root, 'samples')
+ if not os.path.exists(sample_folder):
+ os.makedirs(sample_folder, exist_ok=True)
+ batch_sample_folder = os.path.join(self.save_root, 'samples_batch')
+
+ batch_targets = None
+ batch_inputs = None
+ if batch is not None and not os.path.exists(batch_sample_folder):
+ os.makedirs(batch_sample_folder, exist_ok=True)
+
+ self.model.eval()
+
+ def process_and_save(img, target_img, save_path):
+ img = img.to(self.device, dtype=self.esrgan_dtype)
+ output = self.model(img)
+ # output = (output / 2 + 0.5).clamp(0, 1)
+ output = output.clamp(0, 1)
+ img = img.clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+ img = img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+
+ # convert to pillow image
+ output = Image.fromarray((output * 255).astype(np.uint8))
+ img = Image.fromarray((img * 255).astype(np.uint8))
+
+ if isinstance(target_img, torch.Tensor):
+ # convert to pil
+ target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+ target_img = Image.fromarray((target_img * 255).astype(np.uint8))
+
+ # upscale to size * self.upscale_sample while maintaining pixels
+ output = output.resize(
+ (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
+ resample=Image.NEAREST
+ )
+ img = img.resize(
+ (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
+ resample=Image.NEAREST
+ )
+
+ width, height = output.size
+
+            # stack the input, the upscaled output, and the target image side by side
+ target_image = target_img.resize((width, height))
+ output = output.resize((width, height))
+ img = img.resize((width, height))
+
+ output_img = Image.new('RGB', (width * 3, height))
+
+ output_img.paste(img, (0, 0))
+ output_img.paste(output, (width, 0))
+ output_img.paste(target_image, (width * 2, 0))
+
+ output_img.save(save_path)
+
+ with torch.no_grad():
+ for i, img_url in enumerate(self.sample_sources):
+ img = exif_transpose(Image.open(img_url))
+ img = img.convert('RGB')
+ # crop if not square
+ if img.width != img.height:
+ min_dim = min(img.width, img.height)
+ img = img.crop((0, 0, min_dim, min_dim))
+ # resize
+ img = img.resize((self.resolution * self.zoom, self.resolution * self.zoom), resample=Image.BICUBIC)
+
+ target_image = img
+ # downscale the image input
+ img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC)
+
+                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
+
+ step_num = ''
+ if step is not None:
+ # zero-pad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+ seconds_since_epoch = int(time.time())
+ # zero-pad 2 digits
+ i_str = str(i).zfill(2)
+ filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
+ process_and_save(img, target_image, os.path.join(sample_folder, filename))
+
+ if batch is not None:
+ batch_targets = batch[0].detach()
+ batch_inputs = batch[1].detach()
+ batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
+ batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)
+
+ for i in range(len(batch_inputs)):
+ if step is not None:
+ # zero-pad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+ seconds_since_epoch = int(time.time())
+ # zero-pad 2 digits
+ i_str = str(i).zfill(2)
+ filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
+ process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
+
+ self.model.train()
+
+ def load_model(self):
+ state_dict = None
+ path_to_load = self.pretrained_path
+        # see if we have a checkpoint in our output folder to resume from
+ self.print(f"Looking for latest checkpoint in {self.save_root}")
+ files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
+ files += glob.glob(os.path.join(self.save_root, f"{self.job.name}*.pth"))
+ if files and len(files) > 0:
+ latest_file = max(files, key=os.path.getmtime)
+ print(f" - Latest checkpoint is: {latest_file}")
+ path_to_load = latest_file
+ # todo update step and epoch count
+ elif self.pretrained_path is None:
+ self.print(f" - No checkpoint found, starting from scratch")
+ else:
+ self.print(f" - No checkpoint found, loading pretrained model")
+ self.print(f" - path: {path_to_load}")
+
+ if path_to_load is not None:
+ self.print(f" - Loading pretrained checkpoint: {path_to_load}")
+ # if ends with pth then assume pytorch checkpoint
+ if path_to_load.endswith('.pth') or path_to_load.endswith('.pt'):
+ state_dict = torch.load(path_to_load, map_location=self.device)
+ elif path_to_load.endswith('.safetensors'):
+ state_dict_raw = load_file(path_to_load)
+ # make ordered dict as most things need it
+ state_dict = OrderedDict()
+ for key in esrgan_safetensors_keys:
+ state_dict[key] = state_dict_raw[key]
+ else:
+ raise Exception(f"Unknown file extension for checkpoint: {path_to_load}")
+
+ # todo determine architecture from checkpoint
+ self.model = ESRGAN(
+ state_dict
+ ).to(self.device, dtype=self.esrgan_dtype)
+
+ # set the model to training mode
+ self.model.train()
+ self.model.requires_grad_(True)
+
+ def run(self):
+ super().run()
+ self.load_datasets()
+        steps_per_step = (self.critic.num_critic_per_gen + 1) if self.use_critic else 1
+
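+        # reconcile epochs and max_steps so that whichever limit is reached first ends training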
+ max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step)
+ num_epochs = self.epochs
+ if num_epochs is None or num_epochs > max_step_epochs:
+ num_epochs = max_step_epochs
+
+ max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
+ num_steps = self.max_steps
+ if num_steps is None or num_steps > max_epoch_steps:
+ num_steps = max_epoch_steps
+ self.max_steps = num_steps
+ self.epochs = num_epochs
+ start_step = self.step_num
+ self.first_step = start_step
+
+ self.print(f"Training ESRGAN model:")
+ self.print(f" - Training folder: {self.training_folder}")
+ self.print(f" - Batch size: {self.batch_size}")
+ self.print(f" - Learning rate: {self.learning_rate}")
+ self.print(f" - Epochs: {num_epochs}")
+ self.print(f" - Max steps: {self.max_steps}")
+
+ # load model
+ self.load_model()
+
+ params = self.model.parameters()
+
+ if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+ self.setup_vgg19()
+ self.vgg_19.requires_grad_(False)
+ self.vgg_19.eval()
+ if self.use_critic:
+ self.critic.setup()
+
+ optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+ optimizer_params=self.optimizer_params)
+
+ # setup scheduler
+ # todo allow other schedulers
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
+ optimizer,
+ total_iters=num_steps,
+ factor=1,
+ verbose=False
+ )
+
+ # setup tqdm progress bar
+ self.progress_bar = tqdm(
+ total=num_steps,
+ desc='Training ESRGAN',
+ leave=True
+ )
+
+ blank_losses = OrderedDict({
+ "total": [],
+ "style": [],
+ "content": [],
+ "mse": [],
+ "kl": [],
+ "tv": [],
+ "ptn": [],
+ "crD": [],
+ "crG": [],
+ })
+ epoch_losses = copy.deepcopy(blank_losses)
+ log_losses = copy.deepcopy(blank_losses)
+ print("Generating baseline samples")
+ self.sample(step=0)
+ # range start at self.epoch_num go to self.epochs
+ critic_losses = []
+ for epoch in range(self.epoch_num, self.epochs, 1):
+ if self.step_num >= self.max_steps:
+ break
+ flush()
+ for targets, inputs in self.data_loader:
+ if self.step_num >= self.max_steps:
+ break
+ with torch.no_grad():
+ is_critic_only_step = False
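+                    # with n = num_critic_per_gen, a critic-only step is taken with probability n / (n + 1)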
+ if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
+ is_critic_only_step = True
+
+ targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
+ inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
+
+ optimizer.zero_grad()
+ # dont do grads here for critic step
+ do_grad = not is_critic_only_step
+ with torch.set_grad_enabled(do_grad):
+ pred = self.model(inputs)
+
+ pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+ targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+ if torch.isnan(pred).any():
+ raise ValueError('pred has nan values')
+ if torch.isnan(targets).any():
+ raise ValueError('targets has nan values')
+
+ # Run through VGG19
+ if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+ stacked = torch.cat([pred, targets], dim=0)
+ # stacked = (stacked / 2 + 0.5).clamp(0, 1)
+ stacked = stacked.clamp(0, 1)
+ self.vgg_19(stacked)
+ # make sure we dont have nans
+ if torch.isnan(self.vgg19_pool_4.tensor).any():
+ raise ValueError('vgg19_pool_4 has nan values')
+
+ if is_critic_only_step:
+ critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+ critic_losses.append(critic_d_loss)
+ # don't do generator step
+ continue
+ else:
+ # doing a regular step
+ if len(critic_losses) == 0:
+ critic_d_loss = 0
+ else:
+ critic_d_loss = sum(critic_losses) / len(critic_losses)
+
+ style_loss = self.get_style_loss() * self.style_weight
+ content_loss = self.get_content_loss() * self.content_weight
+
+ mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
+ tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
+ pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
+ if self.use_critic:
+ critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+ else:
+ critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
+
+ loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
+ # make sure non nan
+ if torch.isnan(loss):
+ raise ValueError('loss is nan')
+
+ # Backward pass and optimization
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+ optimizer.step()
+ scheduler.step()
+
+ # update progress bar
+ loss_value = loss.item()
+ # get exponent like 3.54e-4
+ loss_string = f"loss: {loss_value:.2e}"
+ if self.content_weight > 0:
+ loss_string += f" cnt: {content_loss.item():.2e}"
+ if self.style_weight > 0:
+ loss_string += f" sty: {style_loss.item():.2e}"
+ if self.mse_weight > 0:
+ loss_string += f" mse: {mse_loss.item():.2e}"
+ if self.tv_weight > 0:
+ loss_string += f" tv: {tv_loss.item():.2e}"
+ if self.pattern_weight > 0:
+ loss_string += f" ptn: {pattern_loss.item():.2e}"
+ if self.use_critic and self.critic_weight > 0:
+ loss_string += f" crG: {critic_gen_loss.item():.2e}"
+ if self.use_critic:
+ loss_string += f" crD: {critic_d_loss:.2e}"
+
+ if self.optimizer_type.startswith('dadaptation') or self.optimizer_type.startswith('prodigy'):
+ learning_rate = (
+ optimizer.param_groups[0]["d"] *
+ optimizer.param_groups[0]["lr"]
+ )
+ else:
+ learning_rate = optimizer.param_groups[0]['lr']
+
+ lr_critic_string = ''
+ if self.use_critic:
+ lr_critic = self.critic.get_lr()
+ lr_critic_string = f" lrC: {lr_critic:.1e}"
+
+ self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
+ self.progress_bar.set_description(f"E: {epoch}")
+ self.progress_bar.update(1)
+
+ epoch_losses["total"].append(loss_value)
+ epoch_losses["style"].append(style_loss.item())
+ epoch_losses["content"].append(content_loss.item())
+ epoch_losses["mse"].append(mse_loss.item())
+ epoch_losses["tv"].append(tv_loss.item())
+ epoch_losses["ptn"].append(pattern_loss.item())
+ epoch_losses["crG"].append(critic_gen_loss.item())
+ epoch_losses["crD"].append(critic_d_loss)
+
+ log_losses["total"].append(loss_value)
+ log_losses["style"].append(style_loss.item())
+ log_losses["content"].append(content_loss.item())
+ log_losses["mse"].append(mse_loss.item())
+ log_losses["tv"].append(tv_loss.item())
+ log_losses["ptn"].append(pattern_loss.item())
+ log_losses["crG"].append(critic_gen_loss.item())
+ log_losses["crD"].append(critic_d_loss)
+
+ # don't do on first step
+ if self.step_num != start_step:
+ if self.sample_every and self.step_num % self.sample_every == 0:
+ # print above the progress bar
+ self.print(f"Sampling at step {self.step_num}")
+ self.sample(self.step_num, batch=[targets, inputs])
+
+ if self.save_every and self.step_num % self.save_every == 0:
+ # print above the progress bar
+ self.print(f"Saving at step {self.step_num}")
+ self.save(self.step_num)
+
+ if self.log_every and self.step_num % self.log_every == 0:
+ # log to tensorboard
+ if self.writer is not None:
+ # get avg loss
+ for key in log_losses:
+ log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
+ # if log_losses[key] > 0:
+ self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
+ # reset log losses
+ log_losses = copy.deepcopy(blank_losses)
+
+ self.step_num += 1
+ # end epoch
+ if self.writer is not None:
+ eps = 1e-6
+ # get avg loss
+ for key in epoch_losses:
+                    epoch_losses[key] = sum(epoch_losses[key]) / (len(epoch_losses[key]) + eps)
+ if epoch_losses[key] > 0:
+ self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
+ # reset epoch losses
+ epoch_losses = copy.deepcopy(blank_losses)
+
+ self.save()
diff --git a/jobs/process/TrainFineTuneProcess.py b/jobs/process/TrainFineTuneProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..a13a7cf640ad2a695d2f330a8cb4636985593376
--- /dev/null
+++ b/jobs/process/TrainFineTuneProcess.py
@@ -0,0 +1,13 @@
+from collections import OrderedDict
+from jobs import TrainJob
+from jobs.process import BaseTrainProcess
+
+
+class TrainFineTuneProcess(BaseTrainProcess):
+ def __init__(self,process_id: int, job: TrainJob, config: OrderedDict):
+ super().__init__(process_id, job, config)
+
+ def run(self):
+ # implement in child class
+ # be sure to call super().run() first
+ pass
diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc2dc3398a5c29edcaa386e48117644a395db677
--- /dev/null
+++ b/jobs/process/TrainSDRescaleProcess.py
@@ -0,0 +1,277 @@
+import glob
+import os
+from collections import OrderedDict
+import random
+from typing import Optional, List
+
+from safetensors.torch import save_file, load_file
+from tqdm import tqdm
+
+from toolkit.layers import ReductionKernel
+from toolkit.stable_diffusion_model import PromptEmbeds
+from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+import gc
+from toolkit import train_tools
+
+import torch
+from leco import train_util, model_util
+from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
+
+
+def flush():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+class RescaleConfig:
+ def __init__(
+ self,
+ **kwargs
+ ):
+ self.from_resolution = kwargs.get('from_resolution', 512)
+ self.scale = kwargs.get('scale', 0.5)
+ self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None)
+ self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000)
+ self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
+ self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)
+
+
+class PromptEmbedsCache:
+ prompts: dict[str, PromptEmbeds] = {}
+
+ def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
+ self.prompts[__name] = __value
+
+ def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
+ if __name in self.prompts:
+ return self.prompts[__name]
+ else:
+ return None
+
+
+class TrainSDRescaleProcess(BaseSDTrainProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ # pass our custom pipeline to super so it sets it up
+ super().__init__(process_id, job, config)
+ self.step_num = 0
+ self.start_step = 0
+ self.device = self.get_conf('device', self.job.device)
+ self.device_torch = torch.device(self.device)
+ self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
+ self.reduce_size_fn = ReductionKernel(
+ in_channels=4,
+ kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution),
+ dtype=get_torch_dtype(self.train_config.dtype),
+ device=self.device_torch,
+ )
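+        # e.g. with the defaults above (from_resolution=512, scale=0.5 -> to_resolution=256),
+        # the reduction kernel downsamples the latents by a factor of 2 in each spatial dimension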
+
+ self.latent_paths: List[str] = []
+ self.empty_embedding: PromptEmbeds = None
+
+ def before_model_load(self):
+ pass
+
+ def get_latent_tensors(self):
+ dtype = get_torch_dtype(self.train_config.dtype)
+
+ num_to_generate = 0
+ # check if dir exists
+ if not os.path.exists(self.rescale_config.latent_tensor_dir):
+ os.makedirs(self.rescale_config.latent_tensor_dir)
+ num_to_generate = self.rescale_config.num_latent_tensors
+ else:
+ # find existing
+ current_tensor_list = glob.glob(os.path.join(self.rescale_config.latent_tensor_dir, "*.safetensors"))
+ num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list)
+ self.latent_paths = current_tensor_list
+
+ if num_to_generate > 0:
+ print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors")
+
+ # unload other model
+ self.sd.unet.to('cpu')
+
+ # load aux network
+ self.sd_parent = StableDiffusion(
+ self.device_torch,
+ model_config=self.model_config,
+ dtype=self.train_config.dtype,
+ )
+ self.sd_parent.load_model()
+ self.sd_parent.unet.to(self.device_torch, dtype=dtype)
+            # we don't need the text encoder for this
+
+ del self.sd_parent.text_encoder
+ del self.sd_parent.tokenizer
+
+ self.sd_parent.unet.eval()
+ self.sd_parent.unet.requires_grad_(False)
+
+ # save current seed state for training
+ rng_state = torch.get_rng_state()
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+
+ text_embeddings = train_tools.concat_prompt_embeddings(
+ self.empty_embedding, # unconditional (negative prompt)
+ self.empty_embedding, # conditional (positive prompt)
+ self.train_config.batch_size,
+ )
+ torch.set_default_device(self.device_torch)
+
+ for i in tqdm(range(num_to_generate)):
+ dtype = get_torch_dtype(self.train_config.dtype)
+ # get a random seed
+ seed = torch.randint(0, 2 ** 32, (1,)).item()
+ # zero pad seed string to max length
+ seed_string = str(seed).zfill(10)
+ # set seed
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+
+                # use the configured max denoising steps for the scheduler
+ timesteps_to = self.train_config.max_denoising_steps
+
+ # set the scheduler to the number of steps
+ self.sd.noise_scheduler.set_timesteps(
+ timesteps_to, device=self.device_torch
+ )
+
+ noise = self.sd.get_latent_noise(
+ pixel_height=self.rescale_config.from_resolution,
+ pixel_width=self.rescale_config.from_resolution,
+ batch_size=self.train_config.batch_size,
+ noise_offset=self.train_config.noise_offset,
+ ).to(self.device_torch, dtype=dtype)
+
+ # get latents
+ latents = noise * self.sd.noise_scheduler.init_noise_sigma
+ latents = latents.to(self.device_torch, dtype=dtype)
+
+ # get random guidance scale from 1.0 to 10.0 (CFG)
+ guidance_scale = torch.rand(1).item() * 9.0 + 1.0
+
+ # do a timestep of 1
+ timestep = 1
+
+ noise_pred_target = self.sd_parent.predict_noise(
+ latents,
+ text_embeddings=text_embeddings,
+ timestep=timestep,
+ guidance_scale=guidance_scale
+ )
+
+ # build state dict
+ state_dict = OrderedDict()
+ state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16)
+ state_dict['latents'] = latents.to('cpu', dtype=torch.float16)
+ state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16)
+ state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16)
+ state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16)
+ state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32) # must be float 32 to prevent overflow
+
+ file_name = f"{seed_string}_{i}.safetensors"
+ file_path = os.path.join(self.rescale_config.latent_tensor_dir, file_name)
+ save_file(state_dict, file_path)
+ self.latent_paths.append(file_path)
+
+ print("Removing parent model")
+ # delete parent
+ del self.sd_parent
+ flush()
+
+ torch.set_rng_state(rng_state)
+ if cuda_rng_state is not None:
+ torch.cuda.set_rng_state(cuda_rng_state)
+ self.sd.unet.to(self.device_torch, dtype=dtype)
+
+ def hook_before_train_loop(self):
+ # encode our empty prompt
+ self.empty_embedding = self.sd.encode_prompt("")
+ self.empty_embedding = self.empty_embedding.to(self.device_torch,
+ dtype=get_torch_dtype(self.train_config.dtype))
+
+ # Move train model encoder to cpu
+ if isinstance(self.sd.text_encoder, list):
+ for encoder in self.sd.text_encoder:
+ encoder.to('cpu')
+ encoder.eval()
+ encoder.requires_grad_(False)
+ else:
+ self.sd.text_encoder.to('cpu')
+ self.sd.text_encoder.eval()
+ self.sd.text_encoder.requires_grad_(False)
+
+ # self.sd.unet.to('cpu')
+ flush()
+
+ self.get_latent_tensors()
+
+ flush()
+ # end hook_before_train_loop
+
+ def hook_train_loop(self, batch):
+ dtype = get_torch_dtype(self.train_config.dtype)
+
+ loss_function = torch.nn.MSELoss()
+
+ # train it
+ # Begin gradient accumulation
+ self.sd.unet.train()
+ self.sd.unet.requires_grad_(True)
+ self.sd.unet.to(self.device_torch, dtype=dtype)
+
+ with torch.no_grad():
+ self.optimizer.zero_grad()
+
+ # pick random latent tensor
+ latent_path = random.choice(self.latent_paths)
+ latent_tensor = load_file(latent_path)
+
+ noise_pred_target = (latent_tensor['noise_pred_target']).to(self.device_torch, dtype=dtype)
+ latents = (latent_tensor['latents']).to(self.device_torch, dtype=dtype)
+ guidance_scale = (latent_tensor['guidance_scale']).item()
+ timestep = int((latent_tensor['timestep']).item())
+ timesteps_to = int((latent_tensor['timesteps_to']).item())
+ # seed = int((latent_tensor['seed']).item())
+
+ text_embeddings = train_tools.concat_prompt_embeddings(
+ self.empty_embedding, # unconditional (negative prompt)
+ self.empty_embedding, # conditional (positive prompt)
+ self.train_config.batch_size,
+ )
+ self.sd.noise_scheduler.set_timesteps(
+ timesteps_to, device=self.device_torch
+ )
+
+ denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample
+
+ # get the reduced latents
+ # reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
+ denoised_target = self.reduce_size_fn(denoised_target.detach())
+ reduced_latents = self.reduce_size_fn(latents.detach())
+
+ denoised_target.requires_grad = False
+ self.optimizer.zero_grad()
+ noise_pred_train = self.sd.predict_noise(
+ reduced_latents,
+ text_embeddings=text_embeddings,
+ timestep=timestep,
+ guidance_scale=guidance_scale
+ )
+ denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample
+ loss = loss_function(denoised_pred, denoised_target)
+ loss_float = loss.item()
+ loss.backward()
+ self.optimizer.step()
+ self.lr_scheduler.step()
+ self.optimizer.zero_grad()
+
+ flush()
+
+ loss_dict = OrderedDict(
+ {'loss': loss_float},
+ )
+
+ return loss_dict
+ # end hook_train_loop
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..88b9d104e973e00481734182ef55b77767c4be88
--- /dev/null
+++ b/jobs/process/TrainSliderProcess.py
@@ -0,0 +1,694 @@
+import copy
+import os
+import random
+from collections import OrderedDict
+from typing import Union
+
+from PIL import Image
+from diffusers import T2IAdapter
+from torchvision.transforms import transforms
+from tqdm import tqdm
+
+from toolkit.basic import value_map
+from toolkit.config_modules import SliderConfig
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight, apply_learnable_snr_gos
+import gc
+from toolkit import train_tools
+from toolkit.prompt_utils import \
+ EncodedPromptPair, ACTION_TYPES_SLIDER, \
+ EncodedAnchor, concat_prompt_pairs, \
+ concat_anchors, PromptEmbedsCache, encode_prompts_to_cache, build_prompt_pair_batch_from_cache, split_anchors, \
+ split_prompt_pairs
+
+import torch
+from .BaseSDTrainProcess import BaseSDTrainProcess
+
+
+def flush():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+adapter_transforms = transforms.Compose([
+ transforms.ToTensor(),
+])
+
+
+class TrainSliderProcess(BaseSDTrainProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+ self.prompt_txt_list = None
+ self.step_num = 0
+ self.start_step = 0
+ self.device = self.get_conf('device', self.job.device)
+ self.device_torch = torch.device(self.device)
+ self.slider_config = SliderConfig(**self.get_conf('slider', {}))
+ self.prompt_cache = PromptEmbedsCache()
+ self.prompt_pairs: list[EncodedPromptPair] = []
+ self.anchor_pairs: list[EncodedAnchor] = []
+ # keep track of prompt chunk size
+ self.prompt_chunk_size = 1
+
+ # check if we have more targets than steps
+        # this can happen because of permutations when shuffling
+ if len(self.slider_config.targets) > self.train_config.steps:
+ # trim targets
+ self.slider_config.targets = self.slider_config.targets[:self.train_config.steps]
+
+ # get presets
+ self.eval_slider_device_state = get_train_sd_device_state_preset(
+ self.device_torch,
+ train_unet=False,
+ train_text_encoder=False,
+ cached_latents=self.is_latents_cached,
+ train_lora=False,
+ train_adapter=False,
+ train_embedding=False,
+ )
+
+ self.train_slider_device_state = get_train_sd_device_state_preset(
+ self.device_torch,
+ train_unet=self.train_config.train_unet,
+ train_text_encoder=False,
+ cached_latents=self.is_latents_cached,
+ train_lora=True,
+ train_adapter=False,
+ train_embedding=False,
+ )
+
+ def before_model_load(self):
+ pass
+
+ def hook_before_train_loop(self):
+
+ # read line by line from file
+ if self.slider_config.prompt_file:
+ self.print(f"Loading prompt file from {self.slider_config.prompt_file}")
+ with open(self.slider_config.prompt_file, 'r', encoding='utf-8') as f:
+ self.prompt_txt_list = f.readlines()
+ # clean empty lines
+ self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
+
+ self.print(f"Found {len(self.prompt_txt_list)} prompts.")
+
+ if not self.slider_config.prompt_tensors:
+ print(f"Prompt tensors not found. Building prompt tensors for {self.train_config.steps} steps.")
+ # shuffle
+ random.shuffle(self.prompt_txt_list)
+ # trim to max steps
+ self.prompt_txt_list = self.prompt_txt_list[:self.train_config.steps]
+ # trim list to our max steps
+
+ cache = PromptEmbedsCache()
+ print(f"Building prompt cache")
+
+ # get encoded latents for our prompts
+ with torch.no_grad():
+ # list of neutrals. Can come from file or be empty
+ neutral_list = self.prompt_txt_list if self.prompt_txt_list is not None else [""]
+
+ # build the prompts to cache
+ prompts_to_cache = []
+ for neutral in neutral_list:
+ for target in self.slider_config.targets:
+ prompt_list = [
+ f"{target.target_class}", # target_class
+ f"{target.target_class} {neutral}", # target_class with neutral
+ f"{target.positive}", # positive_target
+ f"{target.positive} {neutral}", # positive_target with neutral
+ f"{target.negative}", # negative_target
+ f"{target.negative} {neutral}", # negative_target with neutral
+ f"{neutral}", # neutral
+ f"{target.positive} {target.negative}", # both targets
+ f"{target.negative} {target.positive}", # both targets reverse
+ ]
+ prompts_to_cache += prompt_list
+
+ # remove duplicates
+ prompts_to_cache = list(dict.fromkeys(prompts_to_cache))
+
+ # trim to max steps if max steps is lower than prompt count
+            # todo: this can break if we have more targets than steps; should be fixed by reducing permutations, but could still happen with low steps
+ # prompts_to_cache = prompts_to_cache[:self.train_config.steps]
+
+ # encode them
+ cache = encode_prompts_to_cache(
+ prompt_list=prompts_to_cache,
+ sd=self.sd,
+ cache=cache,
+ prompt_tensor_file=self.slider_config.prompt_tensors
+ )
+
+ prompt_pairs = []
+ prompt_batches = []
+ for neutral in tqdm(neutral_list, desc="Building Prompt Pairs", leave=False):
+ for target in self.slider_config.targets:
+ prompt_pair_batch = build_prompt_pair_batch_from_cache(
+ cache=cache,
+ target=target,
+ neutral=neutral,
+
+ )
+ if self.slider_config.batch_full_slide:
+ # concat the prompt pairs
+ # this allows us to run the entire 4 part process in one shot (for slider)
+ self.prompt_chunk_size = 4
+ concat_prompt_pair_batch = concat_prompt_pairs(prompt_pair_batch).to('cpu')
+ prompt_pairs += [concat_prompt_pair_batch]
+ else:
+ self.prompt_chunk_size = 1
+ # do them one at a time (probably not necessary after new optimizations)
+ prompt_pairs += [x.to('cpu') for x in prompt_pair_batch]
+
+ # setup anchors
+ anchor_pairs = []
+ for anchor in self.slider_config.anchors:
+ # build the cache
+ for prompt in [
+ anchor.prompt,
+ anchor.neg_prompt # empty neutral
+ ]:
+ if cache[prompt] == None:
+ cache[prompt] = self.sd.encode_prompt(prompt)
+
+ anchor_batch = []
+ # we get the prompt pair multiplier from first prompt pair
+ # since they are all the same. We need to match their network polarity
+ prompt_pair_multipliers = prompt_pairs[0].multiplier_list
+ for prompt_multiplier in prompt_pair_multipliers:
+ # match the network multiplier polarity
+ anchor_scalar = 1.0 if prompt_multiplier > 0 else -1.0
+ anchor_batch += [
+ EncodedAnchor(
+ prompt=cache[anchor.prompt],
+ neg_prompt=cache[anchor.neg_prompt],
+ multiplier=anchor.multiplier * anchor_scalar
+ )
+ ]
+
+ anchor_pairs += [
+ concat_anchors(anchor_batch).to('cpu')
+ ]
+ if len(anchor_pairs) > 0:
+ self.anchor_pairs = anchor_pairs
+
+ # move to cpu to save vram
+ # We don't need text encoder anymore, but keep it on cpu for sampling
+ # if text encoder is list
+ if isinstance(self.sd.text_encoder, list):
+ for encoder in self.sd.text_encoder:
+ encoder.to("cpu")
+ else:
+ self.sd.text_encoder.to("cpu")
+ self.prompt_cache = cache
+ self.prompt_pairs = prompt_pairs
+ # self.anchor_pairs = anchor_pairs
+ flush()
+ if self.data_loader is not None:
+ # we will have images, prep the vae
+ self.sd.vae.eval()
+ self.sd.vae.to(self.device_torch)
+ # end hook_before_train_loop
+
+ def before_dataset_load(self):
+ if self.slider_config.use_adapter == 'depth':
+ print(f"Loading T2I Adapter for depth")
+ # called before LoRA network is loaded but after model is loaded
+ # attach the adapter here so it is there before we load the network
+ adapter_path = 'TencentARC/t2iadapter_depth_sd15v2'
+ if self.model_config.is_xl:
+ adapter_path = 'TencentARC/t2i-adapter-depth-midas-sdxl-1.0'
+
+ print(f"Loading T2I Adapter from {adapter_path}")
+
+            # don't name this adapter since we are not training it
+ self.t2i_adapter = T2IAdapter.from_pretrained(
+ adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype), varient="fp16"
+ ).to(self.device_torch)
+ self.t2i_adapter.eval()
+ self.t2i_adapter.requires_grad_(False)
+ flush()
+
+ @torch.no_grad()
+ def get_adapter_images(self, batch: Union[None, 'DataLoaderBatchDTO']):
+
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+ adapter_folder_path = self.slider_config.adapter_img_dir
+ adapter_images = []
+ # loop through images
+ for file_item in batch.file_items:
+ img_path = file_item.path
+ file_name_no_ext = os.path.basename(img_path).split('.')[0]
+ # find the image
+ for ext in img_ext_list:
+ if os.path.exists(os.path.join(adapter_folder_path, file_name_no_ext + ext)):
+ adapter_images.append(os.path.join(adapter_folder_path, file_name_no_ext + ext))
+ break
+ width, height = batch.file_items[0].crop_width, batch.file_items[0].crop_height
+ adapter_tensors = []
+ # load images with torch transforms
+ for idx, adapter_image in enumerate(adapter_images):
+ # scale the image so its smaller dimension matches the batch, then center-crop
+ # the larger dimension to the batch shape
+ img: Image.Image = Image.open(adapter_image)
+ if img.width > img.height:
+ # scale down so height is the same as batch
+ new_height = height
+ new_width = int(img.width * (height / img.height))
+ else:
+ new_width = width
+ new_height = int(img.height * (width / img.width))
+
+ img = img.resize((new_width, new_height))
+ crop_fn = transforms.CenterCrop((height, width))
+ # crop the center to match batch
+ img = crop_fn(img)
+ img = adapter_transforms(img)
+ adapter_tensors.append(img)
+
+ # stack them
+ adapter_tensors = torch.stack(adapter_tensors).to(
+ self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)
+ )
+ return adapter_tensors
+
+ def hook_train_loop(self, batch: Union['DataLoaderBatchDTO', None]):
+ # set to eval mode
+ self.sd.set_device_state(self.eval_slider_device_state)
+ with torch.no_grad():
+ dtype = get_torch_dtype(self.train_config.dtype)
+
+ # get a random pair
+ prompt_pair: EncodedPromptPair = self.prompt_pairs[
+ torch.randint(0, len(self.prompt_pairs), (1,)).item()
+ ]
+ # move to device and dtype
+ prompt_pair.to(self.device_torch, dtype=dtype)
+
+ # get a random resolution
+ height, width = self.slider_config.resolutions[
+ torch.randint(0, len(self.slider_config.resolutions), (1,)).item()
+ ]
+ if self.train_config.gradient_checkpointing:
+ # may get disabled elsewhere
+ self.sd.unet.enable_gradient_checkpointing()
+
+ noise_scheduler = self.sd.noise_scheduler
+ optimizer = self.optimizer
+ lr_scheduler = self.lr_scheduler
+
+ loss_function = torch.nn.MSELoss()
+
+ pred_kwargs = {}
+
+ def get_noise_pred(neg, pos, gs, cts, dn):
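+ # neg / pos: negative and positive prompt embeds, gs: guidance scale,
+ # cts: current timestep, dn: denoised latents to predict noise for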
+ down_kwargs = copy.deepcopy(pred_kwargs)
+ if 'down_block_additional_residuals' in down_kwargs:
+ dbr_batch_size = down_kwargs['down_block_additional_residuals'][0].shape[0]
+ if dbr_batch_size != dn.shape[0]:
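+ # the adapter residuals were computed once per image; tile them so they
+ # cover the CFG-doubled, chunk-expanded latent batch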
+ amount_to_add = int(dn.shape[0] * 2 / dbr_batch_size)
+ down_kwargs['down_block_additional_residuals'] = [
+ torch.cat([sample.clone()] * amount_to_add) for sample in
+ down_kwargs['down_block_additional_residuals']
+ ]
+ return self.sd.predict_noise(
+ latents=dn,
+ text_embeddings=train_tools.concat_prompt_embeddings(
+ neg, # negative prompt
+ pos, # positive prompt
+ self.train_config.batch_size,
+ ),
+ timestep=cts,
+ guidance_scale=gs,
+ **down_kwargs
+ )
+
+ with torch.no_grad():
+ adapter_images = None
+ self.sd.unet.eval()
+
+ # for a complete slider, the prompt pair already holds 4 concatenated polarities, so scale the batch size by that
+ true_batch_size = prompt_pair.target_class.text_embeds.shape[0] * self.train_config.batch_size
+ from_batch = False
+ if batch is not None:
+ # training from a batch of images, not generating ourselves
+ from_batch = True
+ noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
+ if self.slider_config.adapter_img_dir is not None:
+ adapter_images = self.get_adapter_images(batch)
+ adapter_strength_min = 0.9
+ adapter_strength_max = 1.0
+
+ def rand_strength(sample):
+ adapter_conditioning_scale = torch.rand(
+ (1,), device=self.device_torch, dtype=dtype
+ )
+
+ adapter_conditioning_scale = value_map(
+ adapter_conditioning_scale,
+ 0.0,
+ 1.0,
+ adapter_strength_min,
+ adapter_strength_max
+ )
+ return sample.to(self.device_torch, dtype=dtype).detach() * adapter_conditioning_scale
+
+ down_block_additional_residuals = self.t2i_adapter(adapter_images)
+ down_block_additional_residuals = [
+ rand_strength(sample) for sample in down_block_additional_residuals
+ ]
+ pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
+
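+ # repeat the noisy latents once per prompt pair chunk so every slider polarity
+ # is denoised from identical latents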
+ denoised_latents = torch.cat([noisy_latents] * self.prompt_chunk_size, dim=0)
+ current_timestep = timesteps
+ else:
+
+ self.sd.noise_scheduler.set_timesteps(
+ self.train_config.max_denoising_steps, device=self.device_torch
+ )
+
+ # get a random number of steps
+ timesteps_to = torch.randint(
+ 1, self.train_config.max_denoising_steps - 1, (1,)
+ ).item()
+
+ # get noise
+ noise = self.sd.get_latent_noise(
+ pixel_height=height,
+ pixel_width=width,
+ batch_size=true_batch_size,
+ noise_offset=self.train_config.noise_offset,
+ ).to(self.device_torch, dtype=dtype)
+
+ # get latents
+ latents = noise * self.sd.noise_scheduler.init_noise_sigma
+ latents = latents.to(self.device_torch, dtype=dtype)
+
+ assert not self.network.is_active
+ self.sd.unet.eval()
+ # pass the multiplier list to the network
+ # double up since we are doing cfg
+ self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list
+ denoised_latents = self.sd.diffuse_some_steps(
+ latents, # pass simple noise latents
+ train_tools.concat_prompt_embeddings(
+ prompt_pair.positive_target, # unconditional
+ prompt_pair.target_class, # target
+ self.train_config.batch_size,
+ ),
+ start_timesteps=0,
+ total_timesteps=timesteps_to,
+ guidance_scale=3,
+ )
+
+
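+ # map the partial-denoise stopping point onto the full 1000 step schedule
+ # to pick the timestep used for the training predictions below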
+ noise_scheduler.set_timesteps(1000)
+
+ current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
+ current_timestep = noise_scheduler.timesteps[current_timestep_index]
+
+ # split the latents into our prompt pair chunks
+ denoised_latent_chunks = torch.chunk(denoised_latents, self.prompt_chunk_size, dim=0)
+ denoised_latent_chunks = [x.detach() for x in denoised_latent_chunks]
+
+ # flush() # 4.2GB to 3GB on 512x512
+ mask_multiplier = torch.ones((denoised_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
+ has_mask = False
+ if batch and batch.mask_tensor is not None:
+ with self.timer('get_mask_multiplier'):
+ # upsampling is not supported for bfloat16
+ mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
+ # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
+ mask_multiplier = torch.nn.functional.interpolate(
+ mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
+ )
+ # expand to match latents
+ mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
+ mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
+ has_mask = True
+
+ if has_mask:
+ unmasked_target = get_noise_pred(
+ prompt_pair.positive_target, # negative prompt
+ prompt_pair.target_class, # positive prompt
+ 1,
+ current_timestep,
+ denoised_latents
+ )
+ unmasked_target = unmasked_target.detach()
+ unmasked_target.requires_grad = False
+ else:
+ unmasked_target = None
+
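+ # the reference predictions below are made while the LoRA network is inactive,
+ # so they describe the base model and are used to build the training target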
+ # 4.20 GB RAM for 512x512
+ positive_latents = get_noise_pred(
+ prompt_pair.positive_target, # negative prompt
+ prompt_pair.negative_target, # positive prompt
+ 1,
+ current_timestep,
+ denoised_latents
+ )
+ positive_latents = positive_latents.detach()
+ positive_latents.requires_grad = False
+
+ neutral_latents = get_noise_pred(
+ prompt_pair.positive_target, # negative prompt
+ prompt_pair.empty_prompt, # positive prompt (normally neutral)
+ 1,
+ current_timestep,
+ denoised_latents
+ )
+ neutral_latents = neutral_latents.detach()
+ neutral_latents.requires_grad = False
+
+ unconditional_latents = get_noise_pred(
+ prompt_pair.positive_target, # negative prompt
+ prompt_pair.positive_target, # positive prompt
+ 1,
+ current_timestep,
+ denoised_latents
+ )
+ unconditional_latents = unconditional_latents.detach()
+ unconditional_latents.requires_grad = False
+
+ denoised_latents = denoised_latents.detach()
+
+ self.sd.set_device_state(self.train_slider_device_state)
+ self.sd.unet.train()
+ # start accumulating gradients
+ self.optimizer.zero_grad(set_to_none=True)
+
+ anchor_loss_float = None
+ if len(self.anchor_pairs) > 0:
+ with torch.no_grad():
+ # get a random anchor pair
+ anchor: EncodedAnchor = self.anchor_pairs[
+ torch.randint(0, len(self.anchor_pairs), (1,)).item()
+ ]
+ anchor.to(self.device_torch, dtype=dtype)
+
+ # first we get the target prediction without network active
+ anchor_target_noise = get_noise_pred(
+ anchor.neg_prompt, anchor.prompt, 1, current_timestep, denoised_latents
+ # ).to("cpu", dtype=torch.float32)
+ ).requires_grad_(False)
+
+ # to save vram, we will run these through separately while tracking grads
+ # otherwise it consumes a ton of vram and this isn't our speed bottleneck
+ anchor_chunks = split_anchors(anchor, self.prompt_chunk_size)
+ anchor_target_noise_chunks = torch.chunk(anchor_target_noise, self.prompt_chunk_size, dim=0)
+ assert len(anchor_chunks) == len(denoised_latent_chunks)
+
+ # 4.32 GB RAM for 512x512
+ with self.network:
+ assert self.network.is_active
+ anchor_float_losses = []
+ for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
+ anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
+ ):
+ self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list
+
+ anchor_pred_noise = get_noise_pred(
+ anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
+ )
+ # 9.42 GB RAM for 512x512 -> 4.20 GB RAM for 512x512 with new grad_checkpointing
+ anchor_loss = loss_function(
+ anchor_target_noise_chunk,
+ anchor_pred_noise,
+ )
+ anchor_float_losses.append(anchor_loss.item())
+ # compute anchor loss gradients
+ # we will accumulate them later
+ # this saves a ton of memory doing them separately
+ anchor_loss.backward()
+ del anchor_pred_noise
+ del anchor_target_noise_chunk
+ del anchor_loss
+ flush()
+
+ anchor_loss_float = sum(anchor_float_losses) / len(anchor_float_losses)
+ del anchor_chunks
+ del anchor_target_noise_chunks
+ del anchor_target_noise
+ # move anchor back to cpu
+ anchor.to("cpu")
+
+ with torch.no_grad():
+ if self.slider_config.low_ram:
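+ # low_ram: split everything into per-polarity chunks and run them through
+ # the network one at a time to keep peak VRAM down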
+ prompt_pair_chunks = split_prompt_pairs(prompt_pair.detach(), self.prompt_chunk_size)
+ denoised_latent_chunks = denoised_latent_chunks # just to have it in one place
+ positive_latents_chunks = torch.chunk(positive_latents.detach(), self.prompt_chunk_size, dim=0)
+ neutral_latents_chunks = torch.chunk(neutral_latents.detach(), self.prompt_chunk_size, dim=0)
+ unconditional_latents_chunks = torch.chunk(
+ unconditional_latents.detach(),
+ self.prompt_chunk_size,
+ dim=0
+ )
+ mask_multiplier_chunks = torch.chunk(mask_multiplier, self.prompt_chunk_size, dim=0)
+ if unmasked_target is not None:
+ unmasked_target_chunks = torch.chunk(unmasked_target, self.prompt_chunk_size, dim=0)
+ else:
+ unmasked_target_chunks = [None for _ in range(self.prompt_chunk_size)]
+ else:
+ # run through in one instance
+ prompt_pair_chunks = [prompt_pair.detach()]
+ denoised_latent_chunks = [torch.cat(denoised_latent_chunks, dim=0).detach()]
+ positive_latents_chunks = [positive_latents.detach()]
+ neutral_latents_chunks = [neutral_latents.detach()]
+ unconditional_latents_chunks = [unconditional_latents.detach()]
+ mask_multiplier_chunks = [mask_multiplier]
+ unmasked_target_chunks = [unmasked_target]
+
+ # flush()
+ assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
+ # 3.28 GB RAM for 512x512
+ with self.network:
+ assert self.network.is_active
+ loss_list = []
+ for prompt_pair_chunk, \
+ denoised_latent_chunk, \
+ positive_latents_chunk, \
+ neutral_latents_chunk, \
+ unconditional_latents_chunk, \
+ mask_multiplier_chunk, \
+ unmasked_target_chunk \
+ in zip(
+ prompt_pair_chunks,
+ denoised_latent_chunks,
+ positive_latents_chunks,
+ neutral_latents_chunks,
+ unconditional_latents_chunks,
+ mask_multiplier_chunks,
+ unmasked_target_chunks
+ ):
+ self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list
+ target_latents = get_noise_pred(
+ prompt_pair_chunk.positive_target,
+ prompt_pair_chunk.target_class,
+ 1,
+ current_timestep,
+ denoised_latent_chunk
+ )
+
+ guidance_scale = 1.0
+
+ offset = guidance_scale * (positive_latents_chunk - unconditional_latents_chunk)
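+ # edit direction (as in LECO sliders): the (positive - unconditional) concept
+ # vector is added to or subtracted from the neutral prediction, with the sign
+ # chosen per action below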
+
+ # make offset multiplier based on actions
+ offset_multiplier_list = []
+ for action in prompt_pair_chunk.action_list:
+ if action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE:
+ offset_multiplier_list += [-1.0]
+ elif action == ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE:
+ offset_multiplier_list += [1.0]
+
+ offset_multiplier = torch.tensor(offset_multiplier_list).to(offset.device, dtype=offset.dtype)
+ # make offset multiplier match rank of offset
+ offset_multiplier = offset_multiplier.view(offset.shape[0], 1, 1, 1)
+ offset *= offset_multiplier
+
+ offset_neutral = neutral_latents_chunk
+ # offsets are already adjusted on a per-batch basis
+ offset_neutral += offset
+ offset_neutral = offset_neutral.detach().requires_grad_(False)
+
+ # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
+ loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
+
+ # do inverted mask to preserve non masked
+ if has_mask and unmasked_target_chunk is not None:
+ loss = loss * mask_multiplier_chunk
+ # outside the mask, match the unmasked target prediction
+ mask_target_loss = torch.nn.functional.mse_loss(
+ target_latents.float(),
+ unmasked_target_chunk.float(),
+ reduction="none"
+ )
+ mask_target_loss = mask_target_loss * (1.0 - mask_multiplier_chunk)
+ loss += mask_target_loss
+
+ loss = loss.mean([1, 2, 3])
+
+ if self.train_config.learnable_snr_gos:
+ if from_batch:
+ # match batch size
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
+ self.train_config.min_snr_gamma)
+ else:
+ # match batch size
+ timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
+ # add snr_gamma
+ loss = apply_learnable_snr_gos(loss, timesteps_index_list, self.snr_gos)
+ if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+ if from_batch:
+ # match batch size
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler,
+ self.train_config.min_snr_gamma)
+ else:
+ # match batch size
+ timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
+ # add min_snr_gamma
+ loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler,
+ self.train_config.min_snr_gamma)
+
+
+ loss = loss.mean() * prompt_pair_chunk.weight
+
+ loss.backward()
+ loss_list.append(loss.item())
+ del target_latents
+ del offset_neutral
+ del loss
+ # flush()
+
+ optimizer.step()
+ lr_scheduler.step()
+
+ loss_float = sum(loss_list) / len(loss_list)
+ if anchor_loss_float is not None:
+ loss_float += anchor_loss_float
+
+ del (
+ positive_latents,
+ neutral_latents,
+ unconditional_latents,
+ # latents
+ )
+ # move back to cpu
+ prompt_pair.to("cpu")
+ # flush()
+
+ # reset network
+ self.network.multiplier = 1.0
+
+ loss_dict = OrderedDict(
+ {'loss': loss_float},
+ )
+ if anchor_loss_float is not None:
+ loss_dict['sl_l'] = loss_float
+ loss_dict['an_l'] = anchor_loss_float
+
+ return loss_dict
+ # end hook_train_loop
diff --git a/jobs/process/TrainSliderProcessOld.py b/jobs/process/TrainSliderProcessOld.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c25393a3503a99532676b153794fbe89c609160
--- /dev/null
+++ b/jobs/process/TrainSliderProcessOld.py
@@ -0,0 +1,408 @@
+# ref:
+# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
+import time
+from collections import OrderedDict
+import os
+from typing import Optional
+
+from toolkit.config_modules import SliderConfig
+from toolkit.paths import REPOS_ROOT
+import sys
+
+from toolkit.stable_diffusion_model import PromptEmbeds
+
+sys.path.append(REPOS_ROOT)
+sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
+from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+import gc
+from toolkit import train_tools
+
+import torch
+from leco import train_util, model_util
+from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
+
+
+class ACTION_TYPES_SLIDER:
+ ERASE_NEGATIVE = 0
+ ENHANCE_NEGATIVE = 1
+
+
+def flush():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+class EncodedPromptPair:
+ def __init__(
+ self,
+ target_class,
+ positive,
+ negative,
+ neutral,
+ width=512,
+ height=512,
+ action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+ multiplier=1.0,
+ weight=1.0
+ ):
+ self.target_class = target_class
+ self.positive = positive
+ self.negative = negative
+ self.neutral = neutral
+ self.width = width
+ self.height = height
+ self.action: int = action
+ self.multiplier = multiplier
+ self.weight = weight
+
+
+class PromptEmbedsCache: # cached so encoded prompts can be reused
+ prompts: dict[str, PromptEmbeds] = {}
+
+ def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
+ self.prompts[__name] = __value
+
+ def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
+ if __name in self.prompts:
+ return self.prompts[__name]
+ else:
+ return None
+
+
+class EncodedAnchor:
+ def __init__(
+ self,
+ prompt,
+ neg_prompt,
+ multiplier=1.0
+ ):
+ self.prompt = prompt
+ self.neg_prompt = neg_prompt
+ self.multiplier = multiplier
+
+
+class TrainSliderProcessOld(BaseSDTrainProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+ self.step_num = 0
+ self.start_step = 0
+ self.device = self.get_conf('device', self.job.device)
+ self.device_torch = torch.device(self.device)
+ self.slider_config = SliderConfig(**self.get_conf('slider', {}))
+ self.prompt_cache = PromptEmbedsCache()
+ self.prompt_pairs: list[EncodedPromptPair] = []
+ self.anchor_pairs: list[EncodedAnchor] = []
+
+ def before_model_load(self):
+ pass
+
+ def hook_before_train_loop(self):
+ cache = PromptEmbedsCache()
+ prompt_pairs: list[EncodedPromptPair] = []
+
+ # get encoded latents for our prompts
+ with torch.no_grad():
+ neutral = ""
+ for target in self.slider_config.targets:
+ # build the cache
+ for prompt in [
+ target.target_class,
+ target.positive,
+ target.negative,
+ neutral # empty neutral
+ ]:
+ if cache[prompt] is None:
+ cache[prompt] = self.sd.encode_prompt(prompt)
+ for resolution in self.slider_config.resolutions:
+ width, height = resolution
+ only_erase = len(target.positive.strip()) == 0
+ only_enhance = len(target.negative.strip()) == 0
+
+ both = not only_erase and not only_enhance
+
+ if only_erase and only_enhance:
+ raise ValueError("target must have at least one of positive or negative or both")
+ # for slider we need to have an enhancer, an eraser, and then
+ # an inverse with negative weights to balance the network
+ # if we don't do this, we will get different contrast and focus.
+ # we only perform actions of enhancing and erasing on the negative
+ # todo work on way to do all of this in one shot
+
+ if both or only_erase:
+ prompt_pairs += [
+ # erase standard
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ positive=cache[target.positive],
+ negative=cache[target.negative],
+ neutral=cache[neutral],
+ width=width,
+ height=height,
+ action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+ multiplier=target.multiplier,
+ weight=target.weight
+ ),
+ ]
+ if both or only_enhance:
+ prompt_pairs += [
+ # enhance standard, swap pos neg
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ positive=cache[target.negative],
+ negative=cache[target.positive],
+ neutral=cache[neutral],
+ width=width,
+ height=height,
+ action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+ multiplier=target.multiplier,
+ weight=target.weight
+ ),
+ ]
+ if both:
+ prompt_pairs += [
+ # erase inverted
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ positive=cache[target.negative],
+ negative=cache[target.positive],
+ neutral=cache[neutral],
+ width=width,
+ height=height,
+ action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+ multiplier=target.multiplier * -1.0,
+ weight=target.weight
+ ),
+ ]
+ prompt_pairs += [
+ # enhance inverted
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ positive=cache[target.positive],
+ negative=cache[target.negative],
+ neutral=cache[neutral],
+ width=width,
+ height=height,
+ action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+ multiplier=target.multiplier * -1.0,
+ weight=target.weight
+ ),
+ ]
+
+ # setup anchors
+ anchor_pairs = []
+ for anchor in self.slider_config.anchors:
+ # build the cache
+ for prompt in [
+ anchor.prompt,
+ anchor.neg_prompt # empty neutral
+ ]:
+ if cache[prompt] is None:
+ cache[prompt] = self.sd.encode_prompt(prompt)
+
+ anchor_pairs += [
+ EncodedAnchor(
+ prompt=cache[anchor.prompt],
+ neg_prompt=cache[anchor.neg_prompt],
+ multiplier=anchor.multiplier
+ )
+ ]
+
+ # move to cpu to save vram
+ # We don't need text encoder anymore, but keep it on cpu for sampling
+ # if text encoder is list
+ if isinstance(self.sd.text_encoder, list):
+ for encoder in self.sd.text_encoder:
+ encoder.to("cpu")
+ else:
+ self.sd.text_encoder.to("cpu")
+ self.prompt_cache = cache
+ self.prompt_pairs = prompt_pairs
+ self.anchor_pairs = anchor_pairs
+ flush()
+ # end hook_before_train_loop
+
+ def hook_train_loop(self, batch):
+ dtype = get_torch_dtype(self.train_config.dtype)
+
+ # get a random pair
+ prompt_pair: EncodedPromptPair = self.prompt_pairs[
+ torch.randint(0, len(self.prompt_pairs), (1,)).item()
+ ]
+
+ height = prompt_pair.height
+ width = prompt_pair.width
+ target_class = prompt_pair.target_class
+ neutral = prompt_pair.neutral
+ negative = prompt_pair.negative
+ positive = prompt_pair.positive
+ weight = prompt_pair.weight
+ multiplier = prompt_pair.multiplier
+
+ unet = self.sd.unet
+ noise_scheduler = self.sd.noise_scheduler
+ optimizer = self.optimizer
+ lr_scheduler = self.lr_scheduler
+ loss_function = torch.nn.MSELoss()
+
+ def get_noise_pred(p, n, gs, cts, dn):
+ return self.sd.predict_noise(
+ latents=dn,
+ text_embeddings=train_tools.concat_prompt_embeddings(
+ p, # unconditional
+ n, # positive
+ self.train_config.batch_size,
+ ),
+ timestep=cts,
+ guidance_scale=gs,
+ )
+
+ # set network multiplier
+ self.network.multiplier = multiplier
+
+ with torch.no_grad():
+ self.sd.noise_scheduler.set_timesteps(
+ self.train_config.max_denoising_steps, device=self.device_torch
+ )
+
+ self.optimizer.zero_grad()
+
+ # get a random number of steps
+ timesteps_to = torch.randint(
+ 1, self.train_config.max_denoising_steps, (1,)
+ ).item()
+
+ # get noise
+ noise = self.sd.get_latent_noise(
+ pixel_height=height,
+ pixel_width=width,
+ batch_size=self.train_config.batch_size,
+ noise_offset=self.train_config.noise_offset,
+ ).to(self.device_torch, dtype=dtype)
+
+ # get latents
+ latents = noise * self.sd.noise_scheduler.init_noise_sigma
+ latents = latents.to(self.device_torch, dtype=dtype)
+
+ with self.network:
+ assert self.network.is_active
+ self.network.multiplier = multiplier
+ denoised_latents = self.sd.diffuse_some_steps(
+ latents, # pass simple noise latents
+ train_tools.concat_prompt_embeddings(
+ positive, # unconditional
+ target_class, # target
+ self.train_config.batch_size,
+ ),
+ start_timesteps=0,
+ total_timesteps=timesteps_to,
+ guidance_scale=3,
+ )
+
+ noise_scheduler.set_timesteps(1000)
+
+ current_timestep = noise_scheduler.timesteps[
+ int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
+ ]
+
+ positive_latents = get_noise_pred(
+ positive, negative, 1, current_timestep, denoised_latents
+ ).to("cpu", dtype=torch.float32)
+
+ neutral_latents = get_noise_pred(
+ positive, neutral, 1, current_timestep, denoised_latents
+ ).to("cpu", dtype=torch.float32)
+
+ unconditional_latents = get_noise_pred(
+ positive, positive, 1, current_timestep, denoised_latents
+ ).to("cpu", dtype=torch.float32)
+
+ anchor_loss = None
+ if len(self.anchor_pairs) > 0:
+ # get a random anchor pair
+ anchor: EncodedAnchor = self.anchor_pairs[
+ torch.randint(0, len(self.anchor_pairs), (1,)).item()
+ ]
+ with torch.no_grad():
+ anchor_target_noise = get_noise_pred(
+ anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
+ ).to("cpu", dtype=torch.float32)
+ with self.network:
+ # match the anchor to whatever polarity the prompt pair is using
+ pos_neg_mult = 1.0 if prompt_pair.multiplier > 0 else -1.0
+ self.network.multiplier = anchor.multiplier * pos_neg_mult
+
+ anchor_pred_noise = get_noise_pred(
+ anchor.prompt, anchor.neg_prompt, 1, current_timestep, denoised_latents
+ ).to("cpu", dtype=torch.float32)
+
+ self.network.multiplier = prompt_pair.multiplier
+
+ with self.network:
+ self.network.multiplier = prompt_pair.multiplier
+ target_latents = get_noise_pred(
+ positive, target_class, 1, current_timestep, denoised_latents
+ ).to("cpu", dtype=torch.float32)
+
+ # if self.logging_config.verbose:
+ # self.print("target_latents:", target_latents[0, 0, :5, :5])
+
+ positive_latents.requires_grad = False
+ neutral_latents.requires_grad = False
+ unconditional_latents.requires_grad = False
+ if len(self.anchor_pairs) > 0:
+ anchor_target_noise.requires_grad = False
+ anchor_loss = loss_function(
+ anchor_target_noise,
+ anchor_pred_noise,
+ )
+ erase = prompt_pair.action == ACTION_TYPES_SLIDER.ERASE_NEGATIVE
+ guidance_scale = 1.0
+
+ offset = guidance_scale * (positive_latents - unconditional_latents)
+
+ offset_neutral = neutral_latents
+ if erase:
+ offset_neutral -= offset
+ else:
+ # enhance
+ offset_neutral += offset
+
+ loss = loss_function(
+ target_latents,
+ offset_neutral,
+ ) * weight
+
+ loss_slide = loss.item()
+
+ if anchor_loss is not None:
+ loss += anchor_loss
+
+ loss_float = loss.item()
+
+ loss = loss.to(self.device_torch)
+
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+
+ del (
+ positive_latents,
+ neutral_latents,
+ unconditional_latents,
+ target_latents,
+ latents,
+ )
+ flush()
+
+ # reset network
+ self.network.multiplier = 1.0
+
+ loss_dict = OrderedDict(
+ {'loss': loss_float},
+ )
+ if anchor_loss is not None:
+ loss_dict['sl_l'] = loss_slide
+ loss_dict['an_l'] = anchor_loss.item()
+
+ return loss_dict
+ # end hook_train_loop
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6536cdbc62d23104a465623bb7745bd4df00b5
--- /dev/null
+++ b/jobs/process/TrainVAEProcess.py
@@ -0,0 +1,612 @@
+import copy
+import glob
+import os
+import shutil
+import time
+from collections import OrderedDict
+
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import save_file, load_file
+from torch.utils.data import DataLoader, ConcatDataset
+import torch
+from torch import nn
+from torchvision.transforms import transforms
+
+from jobs.process import BaseTrainProcess
+from toolkit.image_utils import show_tensors
+from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
+from toolkit.data_loader import ImageDataset
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.optimizer import get_optimizer
+from toolkit.style import get_style_model_and_losses
+from toolkit.train_tools import get_torch_dtype
+from diffusers import AutoencoderKL
+from tqdm import tqdm
+import time
+import numpy as np
+from .models.vgg19_critic import Critic
+from torchvision.transforms import Resize
+import lpips
+
+IMAGE_TRANSFORMS = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+)
+
+
+def unnormalize(tensor):
+ return (tensor / 2 + 0.5).clamp(0, 1)
+
+
+class TrainVAEProcess(BaseTrainProcess):
+ def __init__(self, process_id: int, job, config: OrderedDict):
+ super().__init__(process_id, job, config)
+ self.data_loader = None
+ self.vae = None
+ self.device = self.get_conf('device', self.job.device)
+ self.vae_path = self.get_conf('vae_path', required=True)
+ self.datasets_objects = self.get_conf('datasets', required=True)
+ self.batch_size = self.get_conf('batch_size', 1, as_type=int)
+ self.resolution = self.get_conf('resolution', 256, as_type=int)
+ self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
+ self.sample_every = self.get_conf('sample_every', None)
+ self.optimizer_type = self.get_conf('optimizer', 'adam')
+ self.epochs = self.get_conf('epochs', None, as_type=int)
+ self.max_steps = self.get_conf('max_steps', None, as_type=int)
+ self.save_every = self.get_conf('save_every', None)
+ self.dtype = self.get_conf('dtype', 'float32')
+ self.sample_sources = self.get_conf('sample_sources', None)
+ self.log_every = self.get_conf('log_every', 100, as_type=int)
+ self.style_weight = self.get_conf('style_weight', 0, as_type=float)
+ self.content_weight = self.get_conf('content_weight', 0, as_type=float)
+ self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
+ self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+ self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+ self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
+ self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
+ self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+ self.optimizer_params = self.get_conf('optimizer_params', {})
+
+ self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
+ self.torch_dtype = get_torch_dtype(self.dtype)
+ self.vgg_19 = None
+ self.style_weight_scalers = []
+ self.content_weight_scalers = []
+ self.lpips_loss:lpips.LPIPS = None
+
+ self.vae_scale_factor = 8
+
+ self.step_num = 0
+ self.epoch_num = 0
+
+ self.use_critic = self.get_conf('use_critic', False, as_type=bool)
+ self.critic = None
+
+ if self.use_critic:
+ self.critic = Critic(
+ device=self.device,
+ dtype=self.dtype,
+ process=self,
+ **self.get_conf('critic', {}) # pass any other params
+ )
+
+ if self.sample_every is not None and self.sample_sources is None:
+ raise ValueError('sample_every is specified but sample_sources is not')
+
+ if self.epochs is None and self.max_steps is None:
+ raise ValueError('epochs or max_steps must be specified')
+
+ self.data_loaders = []
+ # check datasets
+ assert isinstance(self.datasets_objects, list)
+ for dataset in self.datasets_objects:
+ if 'path' not in dataset:
+ raise ValueError('dataset must have a path')
+ # check if is dir
+ if not os.path.isdir(dataset['path']):
+ raise ValueError(f"dataset path does is not a directory: {dataset['path']}")
+
+ # make training folder
+ if not os.path.exists(self.save_root):
+ os.makedirs(self.save_root, exist_ok=True)
+
+ self._pattern_loss = None
+
+ def update_training_metadata(self):
+ self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
+
+ def get_training_info(self):
+ info = OrderedDict({
+ 'step': self.step_num,
+ 'epoch': self.epoch_num,
+ })
+ return info
+
+ def load_datasets(self):
+ if self.data_loader is None:
+ print(f"Loading datasets")
+ datasets = []
+ for dataset in self.datasets_objects:
+ print(f" - Dataset: {dataset['path']}")
+ ds = copy.copy(dataset)
+ ds['resolution'] = self.resolution
+ image_dataset = ImageDataset(ds)
+ datasets.append(image_dataset)
+
+ concatenated_dataset = ConcatDataset(datasets)
+ self.data_loader = DataLoader(
+ concatenated_dataset,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=6
+ )
+
+ def remove_oldest_checkpoint(self):
+ max_to_keep = 4
+ folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
+ if len(folders) > max_to_keep:
+ folders.sort(key=os.path.getmtime)
+ for folder in folders[:-max_to_keep]:
+ print(f"Removing {folder}")
+ shutil.rmtree(folder)
+
+ def setup_vgg19(self):
+ if self.vgg_19 is None:
+ self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
+ single_target=True,
+ device=self.device,
+ output_layer_name='pool_4',
+ dtype=self.torch_dtype
+ )
+ self.vgg_19.to(self.device, dtype=self.torch_dtype)
+ self.vgg_19.requires_grad_(False)
+
+ # we run random noise through first to get layer scalers to normalize the loss per layer
+ # batch size of 2 because we run pred and target stacked together
+ noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
+ self.vgg_19(noise)
+ for style_loss in self.style_losses:
+ # get a scaler to normalize to 1
+ scaler = 1 / torch.mean(style_loss.loss).item()
+ self.style_weight_scalers.append(scaler)
+ for content_loss in self.content_losses:
+ # get a scaler to normalize to 1
+ scaler = 1 / torch.mean(content_loss.loss).item()
+ self.content_weight_scalers.append(scaler)
+
+ self.print(f"Style weight scalers: {self.style_weight_scalers}")
+ self.print(f"Content weight scalers: {self.content_weight_scalers}")
+
+ def get_style_loss(self):
+ if self.style_weight > 0:
+ # scale all losses with loss scalers
+ loss = torch.sum(
+ torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_content_loss(self):
+ if self.content_weight > 0:
+ # scale all losses with loss scalers
+ loss = torch.sum(torch.stack(
+ [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_mse_loss(self, pred, target):
+ if self.mse_weight > 0:
+ loss_fn = nn.MSELoss()
+ loss = loss_fn(pred, target)
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_kld_loss(self, mu, log_var):
+ if self.kld_weight > 0:
+ # Kullback-Leibler divergence
+ # added here for full training (not implemented). Not needed when training only the decoder,
+ # as we are not changing the distribution of the latent space;
+ # normally it would help keep the latents normally distributed
+ KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL divergence
+ return KLD
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_tv_loss(self, pred, target):
+ if self.tv_weight > 0:
+ get_tv_loss = ComparativeTotalVariation()
+ loss = get_tv_loss(pred, target)
+ return loss
+ else:
+ return torch.tensor(0.0, device=self.device)
+
+ def get_pattern_loss(self, pred, target):
+ if self._pattern_loss is None:
+ self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(self.device,
+ dtype=self.torch_dtype)
+ loss = torch.mean(self._pattern_loss(pred, target))
+ return loss
+
+ def save(self, step=None):
+ if not os.path.exists(self.save_root):
+ os.makedirs(self.save_root, exist_ok=True)
+
+ step_num = ''
+ if step is not None:
+ # zeropad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+
+ self.update_training_metadata()
+ filename = f'{self.job.name}{step_num}_diffusers'
+
+ self.vae = self.vae.to("cpu", dtype=torch.float16)
+ self.vae.save_pretrained(
+ save_directory=os.path.join(self.save_root, filename)
+ )
+ self.vae = self.vae.to(self.device, dtype=self.torch_dtype)
+
+ self.print(f"Saved to {os.path.join(self.save_root, filename)}")
+
+ if self.use_critic:
+ self.critic.save(step)
+
+ self.remove_oldest_checkpoint()
+
+ def sample(self, step=None):
+ sample_folder = os.path.join(self.save_root, 'samples')
+ if not os.path.exists(sample_folder):
+ os.makedirs(sample_folder, exist_ok=True)
+
+ with torch.no_grad():
+ for i, img_url in enumerate(self.sample_sources):
+ img = exif_transpose(Image.open(img_url))
+ img = img.convert('RGB')
+ # crop if not square
+ if img.width != img.height:
+ min_dim = min(img.width, img.height)
+ img = img.crop((0, 0, min_dim, min_dim))
+ # resize
+ img = img.resize((self.resolution, self.resolution))
+
+ input_img = img
+ img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
+ img = img
+ decoded = self.vae(img).sample
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+
+ # convert to pillow image
+ decoded = Image.fromarray((decoded * 255).astype(np.uint8))
+
+ # stack input image and decoded image
+ input_img = input_img.resize((self.resolution, self.resolution))
+ decoded = decoded.resize((self.resolution, self.resolution))
+
+ output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
+ output_img.paste(input_img, (0, 0))
+ output_img.paste(decoded, (self.resolution, 0))
+
+ scale_up = 2
+ if output_img.height <= 300:
+ scale_up = 4
+
+ # scale up using nearest neighbor
+ output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
+
+ step_num = ''
+ if step is not None:
+ # zero-pad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+ seconds_since_epoch = int(time.time())
+ # zero-pad 2 digits
+ i_str = str(i).zfill(2)
+ filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
+ output_img.save(os.path.join(sample_folder, filename))
+
+ def load_vae(self):
+ path_to_load = self.vae_path
+ # see if we have a checkpoint in our output to resume from
+ self.print(f"Looking for latest checkpoint in {self.save_root}")
+ files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
+ if files and len(files) > 0:
+ latest_file = max(files, key=os.path.getmtime)
+ print(f" - Latest checkpoint is: {latest_file}")
+ path_to_load = latest_file
+ # todo update step and epoch count
+ else:
+ self.print(f" - No checkpoint found, starting from scratch")
+ # load vae
+ self.print(f"Loading VAE")
+ self.print(f" - Loading VAE: {path_to_load}")
+ if self.vae is None:
+ self.vae = AutoencoderKL.from_pretrained(path_to_load)
+
+ # set decoder to train
+ self.vae.to(self.device, dtype=self.torch_dtype)
+ self.vae.requires_grad_(False)
+ self.vae.eval()
+ self.vae.decoder.train()
+ self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+
+ def run(self):
+ super().run()
+ self.load_datasets()
+
+ max_step_epochs = self.max_steps // len(self.data_loader) if self.max_steps is not None else None
+ num_epochs = self.epochs
+ if max_step_epochs is not None and (num_epochs is None or num_epochs > max_step_epochs):
+ num_epochs = max_step_epochs
+
+ max_epoch_steps = len(self.data_loader) * num_epochs
+ num_steps = self.max_steps
+ if num_steps is None or num_steps > max_epoch_steps:
+ num_steps = max_epoch_steps
+ self.max_steps = num_steps
+ self.epochs = num_epochs
+ start_step = self.step_num
+ self.first_step = start_step
+
+ self.print(f"Training VAE")
+ self.print(f" - Training folder: {self.training_folder}")
+ self.print(f" - Batch size: {self.batch_size}")
+ self.print(f" - Learning rate: {self.learning_rate}")
+ self.print(f" - Epochs: {num_epochs}")
+ self.print(f" - Max steps: {self.max_steps}")
+
+ # load vae
+ self.load_vae()
+
+ params = []
+
+ # freeze the whole decoder first; selected blocks are re-enabled below
+ for param in self.vae.decoder.parameters():
+ param.requires_grad = False
+
+ train_all = 'all' in self.blocks_to_train
+
+ if train_all:
+ params = list(self.vae.decoder.parameters())
+ self.vae.decoder.requires_grad_(True)
+ else:
+ # mid_block
+ if train_all or 'mid_block' in self.blocks_to_train:
+ params += list(self.vae.decoder.mid_block.parameters())
+ self.vae.decoder.mid_block.requires_grad_(True)
+ # up_blocks
+ if train_all or 'up_blocks' in self.blocks_to_train:
+ params += list(self.vae.decoder.up_blocks.parameters())
+ self.vae.decoder.up_blocks.requires_grad_(True)
+ # conv_out (single conv layer output)
+ if train_all or 'conv_out' in self.blocks_to_train:
+ params += list(self.vae.decoder.conv_out.parameters())
+ self.vae.decoder.conv_out.requires_grad_(True)
+
+ if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+ self.setup_vgg19()
+ self.vgg_19.requires_grad_(False)
+ self.vgg_19.eval()
+ if self.use_critic:
+ self.critic.setup()
+
+ if self.lpips_weight > 0 and self.lpips_loss is None:
+ # self.lpips_loss = lpips.LPIPS(net='vgg')
+ self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
+
+ optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+ optimizer_params=self.optimizer_params)
+
+ # setup scheduler
+ # todo allow other schedulers
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
+ optimizer,
+ total_iters=num_steps,
+ factor=1,
+ verbose=False
+ )
+
+ # setup tqdm progress bar
+ self.progress_bar = tqdm(
+ total=num_steps,
+ desc='Training VAE',
+ leave=True
+ )
+
+ # sample first
+ self.sample()
+ blank_losses = OrderedDict({
+ "total": [],
+ "lpips": [],
+ "style": [],
+ "content": [],
+ "mse": [],
+ "kl": [],
+ "tv": [],
+ "ptn": [],
+ "crD": [],
+ "crG": [],
+ })
+ epoch_losses = copy.deepcopy(blank_losses)
+ log_losses = copy.deepcopy(blank_losses)
+ # range start at self.epoch_num go to self.epochs
+ for epoch in range(self.epoch_num, self.epochs, 1):
+ if self.step_num >= self.max_steps:
+ break
+ for batch in self.data_loader:
+ if self.step_num >= self.max_steps:
+ break
+ with torch.no_grad():
+
+ batch = batch.to(self.device, dtype=self.torch_dtype)
+
+ # resize so it matches size of vae evenly
+ if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
+ batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
+ batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
+
+ # forward pass
+ dgd = self.vae.encode(batch).latent_dist
+ mu, logvar = dgd.mean, dgd.logvar
+ latents = dgd.sample()
+ latents = latents.detach().requires_grad_(True)
+
+ pred = self.vae.decode(latents).sample
+
+ with torch.no_grad():
+ show_tensors(
+ pred.clamp(-1, 1).clone(),
+ "combined tensor"
+ )
+
+ # Run through VGG19
+ if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+ stacked = torch.cat([pred, batch], dim=0)
+ stacked = (stacked / 2 + 0.5).clamp(0, 1)
+ self.vgg_19(stacked)
+
+ if self.use_critic:
+ critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+ else:
+ critic_d_loss = 0.0
+
+ style_loss = self.get_style_loss() * self.style_weight
+ content_loss = self.get_content_loss() * self.content_weight
+ kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
+ mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
+ if self.lpips_weight > 0:
+ lpips_loss = self.lpips_loss(
+ pred.clamp(-1, 1),
+ batch.clamp(-1, 1)
+ ).mean() * self.lpips_weight
+ else:
+ lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
+ tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
+ pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
+ if self.use_critic:
+ critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+
+ # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
+ if self.lpips_weight > 0:
+ max_target = lpips_loss.abs() * 0.1
+ with torch.no_grad():
+ crit_g_scaler = 1.0
+ if critic_gen_loss.abs() > max_target:
+ crit_g_scaler = max_target / critic_gen_loss.abs()
+
+ critic_gen_loss *= crit_g_scaler
+ else:
+ critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
+
+ loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss
+
+ # Backward pass and optimization
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+
+ # update progress bar
+ loss_value = loss.item()
+ # get exponent like 3.54e-4
+ loss_string = f"loss: {loss_value:.2e}"
+ if self.lpips_weight > 0:
+ loss_string += f" lpips: {lpips_loss.item():.2e}"
+ if self.content_weight > 0:
+ loss_string += f" cnt: {content_loss.item():.2e}"
+ if self.style_weight > 0:
+ loss_string += f" sty: {style_loss.item():.2e}"
+ if self.kld_weight > 0:
+ loss_string += f" kld: {kld_loss.item():.2e}"
+ if self.mse_weight > 0:
+ loss_string += f" mse: {mse_loss.item():.2e}"
+ if self.tv_weight > 0:
+ loss_string += f" tv: {tv_loss.item():.2e}"
+ if self.pattern_weight > 0:
+ loss_string += f" ptn: {pattern_loss.item():.2e}"
+ if self.use_critic and self.critic_weight > 0:
+ loss_string += f" crG: {critic_gen_loss.item():.2e}"
+ if self.use_critic:
+ loss_string += f" crD: {critic_d_loss:.2e}"
+
+ if self.optimizer_type.startswith('dadaptation') or \
+ self.optimizer_type.lower().startswith('prodigy'):
+ learning_rate = (
+ optimizer.param_groups[0]["d"] *
+ optimizer.param_groups[0]["lr"]
+ )
+ else:
+ learning_rate = optimizer.param_groups[0]['lr']
+
+ lr_critic_string = ''
+ if self.use_critic:
+ lr_critic = self.critic.get_lr()
+ lr_critic_string = f" lrC: {lr_critic:.1e}"
+
+ self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
+ self.progress_bar.set_description(f"E: {epoch}")
+ self.progress_bar.update(1)
+
+ epoch_losses["total"].append(loss_value)
+ epoch_losses["lpips"].append(lpips_loss.item())
+ epoch_losses["style"].append(style_loss.item())
+ epoch_losses["content"].append(content_loss.item())
+ epoch_losses["mse"].append(mse_loss.item())
+ epoch_losses["kl"].append(kld_loss.item())
+ epoch_losses["tv"].append(tv_loss.item())
+ epoch_losses["ptn"].append(pattern_loss.item())
+ epoch_losses["crG"].append(critic_gen_loss.item())
+ epoch_losses["crD"].append(critic_d_loss)
+
+ log_losses["total"].append(loss_value)
+ log_losses["lpips"].append(lpips_loss.item())
+ log_losses["style"].append(style_loss.item())
+ log_losses["content"].append(content_loss.item())
+ log_losses["mse"].append(mse_loss.item())
+ log_losses["kl"].append(kld_loss.item())
+ log_losses["tv"].append(tv_loss.item())
+ log_losses["ptn"].append(pattern_loss.item())
+ log_losses["crG"].append(critic_gen_loss.item())
+ log_losses["crD"].append(critic_d_loss)
+
+ # don't do on first step
+ if self.step_num != start_step:
+ if self.sample_every and self.step_num % self.sample_every == 0:
+ # print above the progress bar
+ self.print(f"Sampling at step {self.step_num}")
+ self.sample(self.step_num)
+
+ if self.save_every and self.step_num % self.save_every == 0:
+ # print above the progress bar
+ self.print(f"Saving at step {self.step_num}")
+ self.save(self.step_num)
+
+ if self.log_every and self.step_num % self.log_every == 0:
+ # log to tensorboard
+ if self.writer is not None:
+ # get avg loss
+ for key in log_losses:
+ log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
+ # if log_losses[key] > 0:
+ self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
+ # reset log losses
+ log_losses = copy.deepcopy(blank_losses)
+
+ self.step_num += 1
+ # end epoch
+ if self.writer is not None:
+ eps = 1e-6
+ # get avg loss
+ for key in epoch_losses:
+ epoch_losses[key] = sum(epoch_losses[key]) / (len(epoch_losses[key]) + eps)
+ if epoch_losses[key] > 0:
+ self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
+ # reset epoch losses
+ epoch_losses = copy.deepcopy(blank_losses)
+
+ self.save()
diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..387be08853c3bcb5cbc551d0b1c1c99dad124df6
--- /dev/null
+++ b/jobs/process/__init__.py
@@ -0,0 +1,15 @@
+from .BaseExtractProcess import BaseExtractProcess
+from .ExtractLoconProcess import ExtractLoconProcess
+from .ExtractLoraProcess import ExtractLoraProcess
+from .BaseProcess import BaseProcess
+from .BaseTrainProcess import BaseTrainProcess
+from .TrainVAEProcess import TrainVAEProcess
+from .BaseMergeProcess import BaseMergeProcess
+from .TrainSliderProcess import TrainSliderProcess
+from .TrainSliderProcessOld import TrainSliderProcessOld
+from .TrainSDRescaleProcess import TrainSDRescaleProcess
+from .ModRescaleLoraProcess import ModRescaleLoraProcess
+from .GenerateProcess import GenerateProcess
+from .BaseExtensionProcess import BaseExtensionProcess
+from .TrainESRGANProcess import TrainESRGANProcess
+from .BaseSDTrainProcess import BaseSDTrainProcess
diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf438bf11487d3daff86266f515f477dcaf88cd
--- /dev/null
+++ b/jobs/process/models/vgg19_critic.py
@@ -0,0 +1,194 @@
+import glob
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from safetensors.torch import load_file, save_file
+
+from toolkit.losses import get_gradient_penalty
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.optimizer import get_optimizer
+from toolkit.train_tools import get_torch_dtype
+
+from typing import TYPE_CHECKING, Union
+
+
+class MeanReduce(nn.Module):
+ def __init__(self):
+ super(MeanReduce, self).__init__()
+
+ def forward(self, inputs):
+ return torch.mean(inputs, dim=(1, 2, 3), keepdim=True)
+
+
+class Vgg19Critic(nn.Module):
+ def __init__(self):
+ # vgg19 input (bs, 3, 512, 512)
+ # pool1 (bs, 64, 256, 256)
+ # pool2 (bs, 128, 128, 128)
+ # pool3 (bs, 256, 64, 64)
+ # pool4 (bs, 512, 32, 32) <- take this input
+
+ super(Vgg19Critic, self).__init__()
+ self.main = nn.Sequential(
+ # input (bs, 512, 32, 32)
+ nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2), # (bs, 1024, 16, 16)
+ nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2), # (bs, 1024, 8, 8)
+ nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+ # (bs, 1024, 4, 4)
+ MeanReduce(), # (bs, 1, 1, 1)
+ nn.Flatten(), # (bs, 1)
+
+ # nn.Flatten(), # (128*8*8) = 8192
+ # nn.Linear(128 * 8 * 8, 1)
+ )
+
+ def forward(self, inputs):
+ return self.main(inputs)
+
+
+if TYPE_CHECKING:
+ from jobs.process.TrainVAEProcess import TrainVAEProcess
+ from jobs.process.TrainESRGANProcess import TrainESRGANProcess
+
+
+class Critic:
+ process: Union['TrainVAEProcess', 'TrainESRGANProcess']
+
+ def __init__(
+ self,
+ learning_rate=1e-5,
+ device='cpu',
+ optimizer='adam',
+ num_critic_per_gen=1,
+ dtype='float32',
+ lambda_gp=10,
+ start_step=0,
+ warmup_steps=1000,
+ process=None,
+ optimizer_params=None,
+ ):
+ self.learning_rate = learning_rate
+ self.device = device
+ self.optimizer_type = optimizer
+ self.num_critic_per_gen = num_critic_per_gen
+ self.dtype = dtype
+ self.torch_dtype = get_torch_dtype(self.dtype)
+ self.process = process
+ self.model = None
+ self.optimizer = None
+ self.scheduler = None
+ self.warmup_steps = warmup_steps
+ self.start_step = start_step
+ self.lambda_gp = lambda_gp
+
+ if optimizer_params is None:
+ optimizer_params = {}
+ self.optimizer_params = optimizer_params
+ self.print = self.process.print
+ print(f" Critic config: {self.__dict__}")
+
+ def setup(self):
+ self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
+ self.load_weights()
+ self.model.train()
+ self.model.requires_grad_(True)
+ params = self.model.parameters()
+ self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
+ optimizer_params=self.optimizer_params)
+ self.scheduler = torch.optim.lr_scheduler.ConstantLR(
+ self.optimizer,
+ total_iters=self.process.max_steps * self.num_critic_per_gen,
+ factor=1,
+ verbose=False
+ )
+
+ def load_weights(self):
+ path_to_load = None
+ self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
+ files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
+ if files and len(files) > 0:
+ latest_file = max(files, key=os.path.getmtime)
+ print(f" - Latest checkpoint is: {latest_file}")
+ path_to_load = latest_file
+ else:
+ self.print(f" - No checkpoint found, starting from scratch")
+ if path_to_load:
+ self.model.load_state_dict(load_file(path_to_load))
+
+ def save(self, step=None):
+ self.process.update_training_metadata()
+ save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
+ step_num = ''
+ if step is not None:
+ # zeropad 9 digits
+ step_num = f"_{str(step).zfill(9)}"
+ save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
+ save_file(self.model.state_dict(), save_path, save_meta)
+ self.print(f"Saved critic to {save_path}")
+
+ def get_critic_loss(self, vgg_output):
+ if self.start_step > self.process.step_num:
+ return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
+
+ warmup_scaler = 1.0
+ # we need a warmup when the critic first comes online at start_step
+ # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps
+ if self.process.step_num < self.start_step + self.warmup_steps:
+ warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
+ # set model to not train for generator loss
+ self.model.eval()
+ self.model.requires_grad_(False)
+ vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+
+ # run model
+ stacked_output = self.model(vgg_pred)
+
+ return (-torch.mean(stacked_output)) * warmup_scaler
+
+ def step(self, vgg_output):
+
+ # train critic here
+ self.model.train()
+ self.model.requires_grad_(True)
+ self.optimizer.zero_grad()
+
+ critic_losses = []
+ inputs = vgg_output.detach()
+ inputs = inputs.to(self.device, dtype=self.torch_dtype)
+ self.optimizer.zero_grad()
+
+ vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+
+ stacked_output = self.model(inputs).float()
+ out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+
+ # Compute gradient penalty
+ gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+
+ # Compute WGAN-GP critic loss
+ critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+ critic_loss.backward()
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+ self.optimizer.step()
+ self.scheduler.step()
+ critic_losses.append(critic_loss.item())
+
+ # avg loss
+ loss = np.mean(critic_losses)
+ return loss
+
+ def get_lr(self):
+ if self.optimizer_type.startswith('dadaptation'):
+ learning_rate = (
+ self.optimizer.param_groups[0]["d"] *
+ self.optimizer.param_groups[0]["lr"]
+ )
+ else:
+ learning_rate = self.optimizer.param_groups[0]['lr']
+
+ return learning_rate
+
diff --git a/notebooks/FLUX_1_dev_LoRA_Training.ipynb b/notebooks/FLUX_1_dev_LoRA_Training.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8cfcd1fedfc941ac1a050f39499f77d303e23783
--- /dev/null
+++ b/notebooks/FLUX_1_dev_LoRA_Training.ipynb
@@ -0,0 +1,291 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": false,
+ "id": "zl-S0m3pkQC5"
+ },
+ "source": [
+ "# AI Toolkit by Ostris\n",
+ "## FLUX.1-dev Training\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BvAG0GKAh59G"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/ostris/ai-toolkit\n",
+ "!mkdir -p /content/dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UFUW4ZMmnp1V"
+ },
+ "source": [
+ "Put your image dataset in the `/content/dataset` folder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XGZqVER_aQJW"
+ },
+ "outputs": [],
+ "source": [
+ "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OV0HnOI6o8V6"
+ },
+ "source": [
+ "## Model License\n",
+ "Training currently only works with FLUX.1-dev. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. Here are the required steps to setup a license.\n",
+ "\n",
+ "Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\n",
+ "\n",
+ "[Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and place it in the next cell after running it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3yZZdhFRoj2m"
+ },
+ "outputs": [],
+ "source": [
+ "import getpass\n",
+ "import os\n",
+ "\n",
+ "# Prompt for the token\n",
+ "hf_token = getpass.getpass('Enter your HF access token and press enter: ')\n",
+ "\n",
+ "# Set the environment variable\n",
+ "os.environ['HF_TOKEN'] = hf_token\n",
+ "\n",
+ "print(\"HF_TOKEN environment variable has been set.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "9gO2EzQ1kQC8"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "sys.path.append('/content/ai-toolkit')\n",
+ "from toolkit.job import run_job\n",
+ "from collections import OrderedDict\n",
+ "from PIL import Image\n",
+ "import os\n",
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "N8UUFzVRigbC"
+ },
+ "source": [
+ "## Setup\n",
+ "\n",
+ "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_t28QURYjRQO"
+ },
+ "outputs": [],
+ "source": [
+ "from collections import OrderedDict\n",
+ "\n",
+ "job_to_run = OrderedDict([\n",
+ " ('job', 'extension'),\n",
+ " ('config', OrderedDict([\n",
+ " # this name will be the folder and filename name\n",
+ " ('name', 'my_first_flux_lora_v1'),\n",
+ " ('process', [\n",
+ " OrderedDict([\n",
+ " ('type', 'sd_trainer'),\n",
+ " # root folder to save training sessions/samples/weights\n",
+ " ('training_folder', '/content/output'),\n",
+ " # uncomment to see performance stats in the terminal every N steps\n",
+ " #('performance_log_every', 1000),\n",
+ " ('device', 'cuda:0'),\n",
+ " # if a trigger word is specified, it will be added to captions of training data if it does not already exist\n",
+ " # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word\n",
+ " # ('trigger_word', 'image'),\n",
+ " ('network', OrderedDict([\n",
+ " ('type', 'lora'),\n",
+ " ('linear', 16),\n",
+ " ('linear_alpha', 16)\n",
+ " ])),\n",
+ " ('save', OrderedDict([\n",
+ " ('dtype', 'float16'), # precision to save\n",
+ " ('save_every', 250), # save every this many steps\n",
+ " ('max_step_saves_to_keep', 4) # how many intermittent saves to keep\n",
+ " ])),\n",
+ " ('datasets', [\n",
+ " # datasets are a folder of images. captions need to be txt files with the same name as the image\n",
+ " # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently\n",
+ " # images will automatically be resized and bucketed into the resolution specified\n",
+ " OrderedDict([\n",
+ " ('folder_path', '/content/dataset'),\n",
+ " ('caption_ext', 'txt'),\n",
+    "            ('caption_dropout_rate', 0.05), # will drop out the caption 5% of the time\n",
+ " ('shuffle_tokens', False), # shuffle caption order, split by commas\n",
+ " ('cache_latents_to_disk', True), # leave this true unless you know what you're doing\n",
+ " ('resolution', [512, 768, 1024]) # flux enjoys multiple resolutions\n",
+ " ])\n",
+ " ]),\n",
+ " ('train', OrderedDict([\n",
+ " ('batch_size', 1),\n",
+ " ('steps', 2000), # total number of steps to train 500 - 4000 is a good range\n",
+ " ('gradient_accumulation_steps', 1),\n",
+ " ('train_unet', True),\n",
+ " ('train_text_encoder', False), # probably won't work with flux\n",
+ " ('content_or_style', 'balanced'), # content, style, balanced\n",
+    "          ('gradient_checkpointing', True), # leave this on unless you have a ton of vram\n",
+ " ('noise_scheduler', 'flowmatch'), # for training only\n",
+ " ('optimizer', 'adamw8bit'),\n",
+ " ('lr', 1e-4),\n",
+ "\n",
+ " # uncomment this to skip the pre training sample\n",
+ " # ('skip_first_sample', True),\n",
+ "\n",
+ " # uncomment to completely disable sampling\n",
+ " # ('disable_sampling', True),\n",
+ "\n",
+    "          # uncomment to use new bell curved weighting. Experimental but may produce better results\n",
+ " # ('linear_timesteps', True),\n",
+ "\n",
+ " # ema will smooth out learning, but could slow it down. Recommended to leave on.\n",
+ " ('ema_config', OrderedDict([\n",
+ " ('use_ema', True),\n",
+ " ('ema_decay', 0.99)\n",
+ " ])),\n",
+ "\n",
+ " # will probably need this if gpu supports it for flux, other dtypes may not work correctly\n",
+ " ('dtype', 'bf16')\n",
+ " ])),\n",
+ " ('model', OrderedDict([\n",
+ " # huggingface model name or path\n",
+ " ('name_or_path', 'black-forest-labs/FLUX.1-dev'),\n",
+ " ('is_flux', True),\n",
+ " ('quantize', True), # run 8bit mixed precision\n",
+ " #('low_vram', True), # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.\n",
+ " ])),\n",
+ " ('sample', OrderedDict([\n",
+ " ('sampler', 'flowmatch'), # must match train.noise_scheduler\n",
+ " ('sample_every', 250), # sample every this many steps\n",
+ " ('width', 1024),\n",
+ " ('height', 1024),\n",
+ " ('prompts', [\n",
+ " # you can add [trigger] to the prompts here and it will be replaced with the trigger word\n",
+ " #'[trigger] holding a sign that says \\'I LOVE PROMPTS!\\'',\n",
+ " 'woman with red hair, playing chess at the park, bomb going off in the background',\n",
+ " 'a woman holding a coffee cup, in a beanie, sitting at a cafe',\n",
+ " 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',\n",
+ " 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',\n",
+ " 'a bear building a log cabin in the snow covered mountains',\n",
+ " 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',\n",
+ " 'hipster man with a beard, building a chair, in a wood shop',\n",
+ " 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',\n",
+ " 'a man holding a sign that says, \\'this is a sign\\'',\n",
+ " 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle'\n",
+ " ]),\n",
+ " ('neg', ''), # not used on flux\n",
+ " ('seed', 42),\n",
+ " ('walk_seed', True),\n",
+ " ('guidance_scale', 4),\n",
+ " ('sample_steps', 20)\n",
+ " ]))\n",
+ " ])\n",
+ " ])\n",
+ " ])),\n",
+ " # you can add any additional meta info here. [name] is replaced with config name at top\n",
+ " ('meta', OrderedDict([\n",
+ " ('name', '[name]'),\n",
+ " ('version', '1.0')\n",
+ " ]))\n",
+ "])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h6F1FlM2Wb3l"
+ },
+ "source": [
+ "## Run it\n",
+ "\n",
+ "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1 In the samples folder, there are preiodic sampled. This doesnt work great with colab. They will be in /content/output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HkajwI8gteOh"
+ },
+ "outputs": [],
+ "source": [
+ "run_job(job_to_run)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Hblgb5uwW5SD"
+ },
+ "source": [
+ "## Done\n",
+ "\n",
+ "Check your ourput dir and get your slider\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "A100",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
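The dataset comments in the notebook above expect each image to have a same-named `.txt` caption and support only jpg, jpeg, and png. A minimal sketch for staging a raw folder into `/content/dataset`; the source path and the placeholder caption are assumptions for illustration, not toolkit conventions:

```python
import shutil
from pathlib import Path

SRC = Path("/content/my_raw_images")  # hypothetical source folder
DST = Path("/content/dataset")        # folder the notebook config points at
DST.mkdir(parents=True, exist_ok=True)

for img in sorted(SRC.iterdir()):
    if img.suffix.lower() not in {".jpg", ".jpeg", ".png"}:
        continue  # only jpg, jpeg, and png are supported by the trainer
    shutil.copy(img, DST / img.name)
    caption_file = DST / f"{img.stem}.txt"
    if not caption_file.exists():
        # bare placeholder caption; replace with real descriptions,
        # optionally containing [trigger] if you set a trigger_word
        caption_file.write_text("a photo of [trigger]")
```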
diff --git a/notebooks/FLUX_1_schnell_LoRA_Training.ipynb b/notebooks/FLUX_1_schnell_LoRA_Training.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..652d8ccc19f8996734785182ce8de46a5c7408fb
--- /dev/null
+++ b/notebooks/FLUX_1_schnell_LoRA_Training.ipynb
@@ -0,0 +1,296 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "collapsed": false,
+ "id": "zl-S0m3pkQC5"
+ },
+ "source": [
+ "# AI Toolkit by Ostris\n",
+ "## FLUX.1-schnell Training\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3cokMT-WC6rG"
+ },
+ "outputs": [],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "BvAG0GKAh59G"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/ostris/ai-toolkit\n",
+ "!mkdir -p /content/dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UFUW4ZMmnp1V"
+ },
+ "source": [
+ "Put your image dataset in the `/content/dataset` folder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "XGZqVER_aQJW"
+ },
+ "outputs": [],
+ "source": [
+ "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OV0HnOI6o8V6"
+ },
+ "source": [
+ "## Model License\n",
+ "Training currently only works with FLUX.1-dev. Which means anything you train will inherit the non-commercial license. It is also a gated model, so you need to accept the license on HF before using it. Otherwise, this will fail. Here are the required steps to setup a license.\n",
+ "\n",
+ "Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\n",
+ "\n",
+ "[Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and place it in the next cell after running it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3yZZdhFRoj2m"
+ },
+ "outputs": [],
+ "source": [
+ "import getpass\n",
+ "import os\n",
+ "\n",
+ "# Prompt for the token\n",
+ "hf_token = getpass.getpass('Enter your HF access token and press enter: ')\n",
+ "\n",
+ "# Set the environment variable\n",
+ "os.environ['HF_TOKEN'] = hf_token\n",
+ "\n",
+ "print(\"HF_TOKEN environment variable has been set.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "9gO2EzQ1kQC8"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "sys.path.append('/content/ai-toolkit')\n",
+ "from toolkit.job import run_job\n",
+ "from collections import OrderedDict\n",
+ "from PIL import Image\n",
+ "import os\n",
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "N8UUFzVRigbC"
+ },
+ "source": [
+ "## Setup\n",
+ "\n",
+ "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "_t28QURYjRQO"
+ },
+ "outputs": [],
+ "source": [
+ "from collections import OrderedDict\n",
+ "\n",
+ "job_to_run = OrderedDict([\n",
+ " ('job', 'extension'),\n",
+ " ('config', OrderedDict([\n",
+ " # this name will be the folder and filename name\n",
+ " ('name', 'my_first_flux_lora_v1'),\n",
+ " ('process', [\n",
+ " OrderedDict([\n",
+ " ('type', 'sd_trainer'),\n",
+ " # root folder to save training sessions/samples/weights\n",
+ " ('training_folder', '/content/output'),\n",
+ " # uncomment to see performance stats in the terminal every N steps\n",
+ " #('performance_log_every', 1000),\n",
+ " ('device', 'cuda:0'),\n",
+ " # if a trigger word is specified, it will be added to captions of training data if it does not already exist\n",
+ " # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word\n",
+ " # ('trigger_word', 'image'),\n",
+ " ('network', OrderedDict([\n",
+ " ('type', 'lora'),\n",
+ " ('linear', 16),\n",
+ " ('linear_alpha', 16)\n",
+ " ])),\n",
+ " ('save', OrderedDict([\n",
+ " ('dtype', 'float16'), # precision to save\n",
+ " ('save_every', 250), # save every this many steps\n",
+ " ('max_step_saves_to_keep', 4) # how many intermittent saves to keep\n",
+ " ])),\n",
+ " ('datasets', [\n",
+ " # datasets are a folder of images. captions need to be txt files with the same name as the image\n",
+ " # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently\n",
+ " # images will automatically be resized and bucketed into the resolution specified\n",
+ " OrderedDict([\n",
+ " ('folder_path', '/content/dataset'),\n",
+ " ('caption_ext', 'txt'),\n",
+    "            ('caption_dropout_rate', 0.05), # will drop out the caption 5% of the time\n",
+ " ('shuffle_tokens', False), # shuffle caption order, split by commas\n",
+ " ('cache_latents_to_disk', True), # leave this true unless you know what you're doing\n",
+ " ('resolution', [512, 768, 1024]) # flux enjoys multiple resolutions\n",
+ " ])\n",
+ " ]),\n",
+ " ('train', OrderedDict([\n",
+ " ('batch_size', 1),\n",
+ " ('steps', 2000), # total number of steps to train 500 - 4000 is a good range\n",
+ " ('gradient_accumulation_steps', 1),\n",
+ " ('train_unet', True),\n",
+ " ('train_text_encoder', False), # probably won't work with flux\n",
+    "          ('gradient_checkpointing', True), # leave this on unless you have a ton of vram\n",
+ " ('noise_scheduler', 'flowmatch'), # for training only\n",
+ " ('optimizer', 'adamw8bit'),\n",
+ " ('lr', 1e-4),\n",
+ "\n",
+ " # uncomment this to skip the pre training sample\n",
+ " # ('skip_first_sample', True),\n",
+ "\n",
+ " # uncomment to completely disable sampling\n",
+ " # ('disable_sampling', True),\n",
+ "\n",
+    "          # uncomment to use new bell curved weighting. Experimental but may produce better results\n",
+ " # ('linear_timesteps', True),\n",
+ "\n",
+ " # ema will smooth out learning, but could slow it down. Recommended to leave on.\n",
+ " ('ema_config', OrderedDict([\n",
+ " ('use_ema', True),\n",
+ " ('ema_decay', 0.99)\n",
+ " ])),\n",
+ "\n",
+ " # will probably need this if gpu supports it for flux, other dtypes may not work correctly\n",
+ " ('dtype', 'bf16')\n",
+ " ])),\n",
+ " ('model', OrderedDict([\n",
+ " # huggingface model name or path\n",
+ " ('name_or_path', 'black-forest-labs/FLUX.1-schnell'),\n",
+ " ('assistant_lora_path', 'ostris/FLUX.1-schnell-training-adapter'), # Required for flux schnell training\n",
+ " ('is_flux', True),\n",
+ " ('quantize', True), # run 8bit mixed precision\n",
+    "          # low_vram is painfully slow to fuse in the adapter; avoid it unless absolutely necessary\n",
+ " #('low_vram', True), # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.\n",
+ " ])),\n",
+ " ('sample', OrderedDict([\n",
+ " ('sampler', 'flowmatch'), # must match train.noise_scheduler\n",
+ " ('sample_every', 250), # sample every this many steps\n",
+ " ('width', 1024),\n",
+ " ('height', 1024),\n",
+ " ('prompts', [\n",
+ " # you can add [trigger] to the prompts here and it will be replaced with the trigger word\n",
+ " #'[trigger] holding a sign that says \\'I LOVE PROMPTS!\\'',\n",
+ " 'woman with red hair, playing chess at the park, bomb going off in the background',\n",
+ " 'a woman holding a coffee cup, in a beanie, sitting at a cafe',\n",
+ " 'a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini',\n",
+ " 'a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background',\n",
+ " 'a bear building a log cabin in the snow covered mountains',\n",
+ " 'woman playing the guitar, on stage, singing a song, laser lights, punk rocker',\n",
+ " 'hipster man with a beard, building a chair, in a wood shop',\n",
+ " 'photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop',\n",
+ " 'a man holding a sign that says, \\'this is a sign\\'',\n",
+ " 'a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle'\n",
+ " ]),\n",
+ " ('neg', ''), # not used on flux\n",
+ " ('seed', 42),\n",
+ " ('walk_seed', True),\n",
+ " ('guidance_scale', 1), # schnell does not do guidance\n",
+ " ('sample_steps', 4) # 1 - 4 works well\n",
+ " ]))\n",
+ " ])\n",
+ " ])\n",
+ " ])),\n",
+ " # you can add any additional meta info here. [name] is replaced with config name at top\n",
+ " ('meta', OrderedDict([\n",
+ " ('name', '[name]'),\n",
+ " ('version', '1.0')\n",
+ " ]))\n",
+ "])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h6F1FlM2Wb3l"
+ },
+ "source": [
+ "## Run it\n",
+ "\n",
+ "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1 In the samples folder, there are preiodic sampled. This doesnt work great with colab. They will be in /content/output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HkajwI8gteOh"
+ },
+ "outputs": [],
+ "source": [
+ "run_job(job_to_run)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Hblgb5uwW5SD"
+ },
+ "source": [
+ "## Done\n",
+ "\n",
+ "Check your ourput dir and get your slider\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "A100",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
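Comparing the two notebooks, the schnell config differs from the dev one only in the base model, the required assistant adapter, and the sampling settings. A small sketch of patching a dev-style `job_to_run` dict for schnell; the helper name is illustrative, the key paths mirror the config above:

```python
def patch_for_schnell(job: dict) -> dict:
    """Adjust a FLUX.1-dev style config for FLUX.1-schnell training."""
    process = job["config"]["process"][0]
    model = process["model"]
    model["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
    # the assistant LoRA is required for schnell training
    model["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
    sample = process["sample"]
    sample["guidance_scale"] = 1  # schnell does not do guidance
    sample["sample_steps"] = 4    # 1 - 4 works well
    return job
```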
diff --git a/notebooks/SliderTraining.ipynb b/notebooks/SliderTraining.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8465ec87dc2d2dce8f11e122c28c80297e3ea2b9
--- /dev/null
+++ b/notebooks/SliderTraining.ipynb
@@ -0,0 +1,339 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "V100"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# AI Toolkit by Ostris\n",
+ "## Slider Training\n",
+ "\n",
+ "This is a quick colab demo for training sliders like can be found in my CivitAI profile https://civitai.com/user/Ostris/models . I will work on making it more user friendly, but for now, it will get you started."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/ostris/ai-toolkit"
+ ],
+ "metadata": {
+ "id": "BvAG0GKAh59G"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XGZqVER_aQJW"
+ },
+ "outputs": [],
+ "source": [
+ "!cd ai-toolkit && git submodule update --init --recursive && pip install -r requirements.txt\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "sys.path.append('/content/ai-toolkit')\n",
+ "from toolkit.job import run_job\n",
+ "from collections import OrderedDict\n",
+ "from PIL import Image"
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Setup\n",
+ "\n",
+ "This is your config. It is documented pretty well. Normally you would do this as a yaml file, but for colab, this will work. This will run as is without modification, but feel free to edit as you want."
+ ],
+ "metadata": {
+ "id": "N8UUFzVRigbC"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from collections import OrderedDict\n",
+ "\n",
+ "job_to_run = OrderedDict({\n",
+    "  # This is the config I use on my sliders. It is solid and tested\n",
+ " 'job': 'train',\n",
+ " 'config': {\n",
+ " # the name will be used to create a folder in the output folder\n",
+ " # it will also replace any [name] token in the rest of this config\n",
+ " 'name': 'detail_slider_v1',\n",
+ " # folder will be created with name above in folder below\n",
+ " # it can be relative to the project root or absolute\n",
+ " 'training_folder': \"output/LoRA\",\n",
+ " 'device': 'cuda', # cpu, cuda:0, etc\n",
+ " # for tensorboard logging, we will make a subfolder for this job\n",
+ " 'log_dir': \"output/.tensorboard\",\n",
+    "    # you can stack processes for other jobs, but it is not tested with sliders\n",
+ " # just use one for now\n",
+ " 'process': [\n",
+ " {\n",
+ " 'type': 'slider', # tells runner to run the slider process\n",
+ " # network is the LoRA network for a slider, I recommend to leave this be\n",
+ " 'network': {\n",
+ " 'type': \"lora\",\n",
+ " # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good\n",
+ " 'linear': 8, # \"rank\" or \"dim\"\n",
+ " 'linear_alpha': 4, # Do about half of rank \"alpha\"\n",
+ " # 'conv': 4, # for convolutional layers \"locon\"\n",
+ " # 'conv_alpha': 4, # Do about half of conv \"alpha\"\n",
+ " },\n",
+ " # training config\n",
+ " 'train': {\n",
+ " # this is also used in sampling. Stick with ddpm unless you know what you are doing\n",
+ " 'noise_scheduler': \"ddpm\", # or \"ddpm\", \"lms\", \"euler_a\"\n",
+ " # how many steps to train. More is not always better. I rarely go over 1000\n",
+ " 'steps': 100,\n",
+ " # I have had good results with 4e-4 to 1e-4 at 500 steps\n",
+ " 'lr': 2e-4,\n",
+ " # enables gradient checkpoint, saves vram, leave it on\n",
+ " 'gradient_checkpointing': True,\n",
+ " # train the unet. I recommend leaving this true\n",
+ " 'train_unet': True,\n",
+ " # train the text encoder. I don't recommend this unless you have a special use case\n",
+ " # for sliders we are adjusting representation of the concept (unet),\n",
+ " # not the description of it (text encoder)\n",
+ " 'train_text_encoder': False,\n",
+ "\n",
+ " # just leave unless you know what you are doing\n",
+ " # also supports \"dadaptation\" but set lr to 1 if you use that,\n",
+ " # but it learns too fast and I don't recommend it\n",
+ " 'optimizer': \"adamw\",\n",
+ " # only constant for now\n",
+ " 'lr_scheduler': \"constant\",\n",
+    "      # we randomly denoise a random number of steps from 1 to this number\n",
+ " # while training. Just leave it\n",
+ " 'max_denoising_steps': 40,\n",
+ " # works great at 1. I do 1 even with my 4090.\n",
+ " # higher may not work right with newer single batch stacking code anyway\n",
+ " 'batch_size': 1,\n",
+ " # bf16 works best if your GPU supports it (modern)\n",
+ " 'dtype': 'bf16', # fp32, bf16, fp16\n",
+ " # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX\n",
+ " # although, the way we train sliders is comparative, so it probably won't work anyway\n",
+ " 'noise_offset': 0.0,\n",
+ " },\n",
+ "\n",
+ " # the model to train the LoRA network on\n",
+ " 'model': {\n",
+ " # name_or_path can be a hugging face name, local path or url to model\n",
+ " # on civit ai with or without modelVersionId. They will be cached in /model folder\n",
+ " # epicRealisim v5\n",
+ " 'name_or_path': \"https://civitai.com/models/25694?modelVersionId=134065\",\n",
+ " 'is_v2': False, # for v2 models\n",
+ " 'is_v_pred': False, # for v-prediction models (most v2 models)\n",
+ " # has some issues with the dual text encoder and the way we train sliders\n",
+    "        # it works, but weights probably need to be higher to see the effect.\n",
+ " 'is_xl': False, # for SDXL models\n",
+ " },\n",
+ "\n",
+ " # saving config\n",
+ " 'save': {\n",
+ " 'dtype': 'float16', # precision to save. I recommend float16\n",
+ " 'save_every': 50, # save every this many steps\n",
+ " # this will remove step counts more than this number\n",
+ " # allows you to save more often in case of a crash without filling up your drive\n",
+ " 'max_step_saves_to_keep': 2,\n",
+ " },\n",
+ "\n",
+ " # sampling config\n",
+ " 'sample': {\n",
+ " # must match train.noise_scheduler, this is not used here\n",
+ " # but may be in future and in other processes\n",
+ " 'sampler': \"ddpm\",\n",
+ " # sample every this many steps\n",
+ " 'sample_every': 20,\n",
+ " # image size\n",
+ " 'width': 512,\n",
+ " 'height': 512,\n",
+ " # prompts to use for sampling. Do as many as you want, but it slows down training\n",
+ " # pick ones that will best represent the concept you are trying to adjust\n",
+ " # allows some flags after the prompt\n",
+ " # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive\n",
+ " # slide are good tests. will inherit sample.network_multiplier if not set\n",
+ " # --n [string] # negative prompt, will inherit sample.neg if not set\n",
+ " # Only 75 tokens allowed currently\n",
+ " # I like to do a wide positive and negative spread so I can see a good range and stop\n",
+    "      # early if the network is breaking down\n",
+ " 'prompts': [\n",
+ " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5\",\n",
+ " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3\",\n",
+ " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3\",\n",
+ " \"a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5\",\n",
+ " \"a golden retriever sitting on a leather couch, --m -5\",\n",
+ " \"a golden retriever sitting on a leather couch --m -3\",\n",
+ " \"a golden retriever sitting on a leather couch --m 3\",\n",
+ " \"a golden retriever sitting on a leather couch --m 5\",\n",
+ " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5\",\n",
+ " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3\",\n",
+ " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3\",\n",
+ " \"a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5\",\n",
+ " ],\n",
+ " # negative prompt used on all prompts above as default if they don't have one\n",
+ " 'neg': \"cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome\",\n",
+ " # seed for sampling. 42 is the answer for everything\n",
+ " 'seed': 42,\n",
+ " # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc\n",
+ " # will start over on next sample_every so s1 is always seed\n",
+ " # works well if you use same prompt but want different results\n",
+ " 'walk_seed': False,\n",
+ " # cfg scale (4 to 10 is good)\n",
+ " 'guidance_scale': 7,\n",
+ " # sampler steps (20 to 30 is good)\n",
+ " 'sample_steps': 20,\n",
+ " # default network multiplier for all prompts\n",
+ " # since we are training a slider, I recommend overriding this with --m [number]\n",
+ " # in the prompts above to get both sides of the slider\n",
+ " 'network_multiplier': 1.0,\n",
+ " },\n",
+ "\n",
+ " # logging information\n",
+ " 'logging': {\n",
+ " 'log_every': 10, # log every this many steps\n",
+ " 'use_wandb': False, # not supported yet\n",
+    "      'verbose': False, # probably don't need this unless you are debugging\n",
+ " },\n",
+ "\n",
+ " # slider training config, best for last\n",
+ " 'slider': {\n",
+ " # resolutions to train on. [ width, height ]. This is less important for sliders\n",
+ " # as we are not teaching the model anything it doesn't already know\n",
+ " # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1\n",
+ " # and [ 1024, 1024 ] for sd_xl\n",
+ " # you can do as many as you want here\n",
+ " 'resolutions': [\n",
+ " [512, 512],\n",
+ " # [ 512, 768 ]\n",
+ " # [ 768, 768 ]\n",
+ " ],\n",
+ " # slider training uses 4 combined steps for a single round. This will do it in one gradient\n",
+    "      # step. It is highly optimized and shouldn't take any more vram than doing it without,\n",
+ " # since we break down batches for gradient accumulation now. so just leave it on.\n",
+ " 'batch_full_slide': True,\n",
+ " # These are the concepts to train on. You can do as many as you want here,\n",
+    "      # but they can conflict and outweigh each other. Other than experimenting, I recommend\n",
+ " # just doing one for good results\n",
+ " 'targets': [\n",
+ " # target_class is the base concept we are adjusting the representation of\n",
+ " # for example, if we are adjusting the representation of a person, we would use \"person\"\n",
+ " # if we are adjusting the representation of a cat, we would use \"cat\" It is not\n",
+ " # a keyword necessarily but what the model understands the concept to represent.\n",
+ " # \"person\" will affect men, women, children, etc but will not affect cats, dogs, etc\n",
+ " # it is the models base general understanding of the concept and everything it represents\n",
+ " # you can leave it blank to affect everything. In this example, we are adjusting\n",
+ " # detail, so we will leave it blank to affect everything\n",
+ " {\n",
+ " 'target_class': \"\",\n",
+ " # positive is the prompt for the positive side of the slider.\n",
+ " # It is the concept that will be excited and amplified in the model when we slide the slider\n",
+ " # to the positive side and forgotten / inverted when we slide\n",
+ " # the slider to the negative side. It is generally best to include the target_class in\n",
+ " # the prompt. You want it to be the extreme of what you want to train on. For example,\n",
+ " # if you want to train on fat people, you would use \"an extremely fat, morbidly obese person\"\n",
+ " # as the prompt. Not just \"fat person\"\n",
+ " # max 75 tokens for now\n",
+ " 'positive': \"high detail, 8k, intricate, detailed, high resolution, high res, high quality\",\n",
+ " # negative is the prompt for the negative side of the slider and works the same as positive\n",
+ " # it does not necessarily work the same as a negative prompt when generating images\n",
+ " # these need to be polar opposites.\n",
+ " # max 76 tokens for now\n",
+ " 'negative': \"blurry, boring, fuzzy, low detail, low resolution, low res, low quality\",\n",
+ " # the loss for this target is multiplied by this number.\n",
+ " # if you are doing more than one target it may be good to set less important ones\n",
+ " # to a lower number like 0.1 so they don't outweigh the primary target\n",
+ " 'weight': 1.0,\n",
+ " },\n",
+ " ],\n",
+ " },\n",
+ " },\n",
+ " ]\n",
+ " },\n",
+ "\n",
+ " # You can put any information you want here, and it will be saved in the model.\n",
+ " # The below is an example, but you can put your grocery list in it if you want.\n",
+ " # It is saved in the model so be aware of that. The software will include this\n",
+ " # plus some other information for you automatically\n",
+ " 'meta': {\n",
+ " # [name] gets replaced with the name above\n",
+ " 'name': \"[name]\",\n",
+ " 'version': '1.0',\n",
+ " # 'creator': {\n",
+ " # 'name': 'your name',\n",
+ " # 'email': 'your@gmail.com',\n",
+ " # 'website': 'https://your.website'\n",
+ " # }\n",
+ " }\n",
+ "})\n"
+ ],
+ "metadata": {
+ "id": "_t28QURYjRQO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Run it\n",
+ "\n",
+ "Below does all the magic. Check your folders to the left. Items will be in output/LoRA/your_name_v1 In the samples folder, there are preiodic sampled. This doesnt work great with colab. Ill update soon."
+ ],
+ "metadata": {
+ "id": "h6F1FlM2Wb3l"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "run_job(job_to_run)\n"
+ ],
+ "metadata": {
+ "id": "HkajwI8gteOh"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Done\n",
+ "\n",
+ "Check your ourput dir and get your slider\n"
+ ],
+ "metadata": {
+ "id": "Hblgb5uwW5SD"
+ }
+ }
+ ]
+}
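The slider sample prompts above carry per-prompt flags such as `--m -5` (network multiplier) and `--n ...` (negative prompt). Purely to illustrate what those flags mean, here is a rough parser sketch; the toolkit has its own prompt handling, and this is not it:

```python
import re

def parse_prompt_flags(prompt: str, default_multiplier: float = 1.0,
                       default_neg: str = "") -> tuple[str, float, str]:
    """Split '--m <number>' and '--n <text>' flags off a sample prompt."""
    multiplier = default_multiplier
    neg = default_neg

    m_match = re.search(r"--m\s+(-?\d+(?:\.\d+)?)", prompt)
    if m_match:
        multiplier = float(m_match.group(1))

    n_match = re.search(r"--n\s+(.+)$", prompt)
    if n_match:
        neg = n_match.group(1).strip()

    clean = re.sub(r"--[mn]\s+\S.*?(?=--|$)", "", prompt).strip().rstrip(",")
    return clean, multiplier, neg

print(parse_prompt_flags("a golden retriever sitting on a leather couch --m -5"))
# -> ('a golden retriever sitting on a leather couch', -5.0, '')
```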
diff --git a/output/.gitkeep b/output/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1ab621c1f2d4f545808723f8668b7c9276354c9d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+gradio
+huggingface_hub
+python-slugify
+oyaml
+modal
+python-dotenv
\ No newline at end of file
diff --git a/requirements_local.txt b/requirements_local.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f45766b8d5b7eed9c0e0e9bcf9b0f5c0a3cc3588
--- /dev/null
+++ b/requirements_local.txt
@@ -0,0 +1,35 @@
+torch
+torchvision
+safetensors
+git+https://github.com/huggingface/diffusers.git
+transformers
+lycoris-lora==1.8.3
+flatten_json
+pyyaml
+oyaml
+tensorboard
+kornia
+invisible-watermark
+einops
+accelerate
+toml
+albumentations==1.4.15
+albucore==0.0.16
+pydantic
+omegaconf
+k-diffusion
+open_clip_torch
+timm
+prodigyopt
+controlnet_aux==0.0.7
+python-dotenv
+bitsandbytes
+hf_transfer
+lpips
+pytorch_fid
+optimum-quanto==0.2.4
+sentencepiece
+huggingface_hub
+peft
+gradio
+python-slugify
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f13308117a84bc47a054c29d2088eec815572d8
--- /dev/null
+++ b/run.py
@@ -0,0 +1,90 @@
+import os
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+import sys
+from typing import Union, OrderedDict
+from dotenv import load_dotenv
+# Load the .env file if it exists
+load_dotenv()
+
+sys.path.insert(0, os.getcwd())
+# must come before ANY torch or fastai imports
+# import toolkit.cuda_malloc
+
+# turn off diffusers telemetry until I can figure out how to make it opt-in
+os.environ['DISABLE_TELEMETRY'] = 'YES'
+
+# check if we have DEBUG_TOOLKIT in env
+if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
+ # set torch to trace mode
+ import torch
+ torch.autograd.set_detect_anomaly(True)
+import argparse
+from toolkit.job import get_job
+
+
+def print_end_message(jobs_completed, jobs_failed):
+ failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
+ completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
+
+ print("")
+ print("========================================")
+ print("Result:")
+ if len(completed_string) > 0:
+ print(f" - {completed_string}")
+ if len(failure_string) > 0:
+ print(f" - {failure_string}")
+ print("========================================")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+    # require at least one config file
+ parser.add_argument(
+ 'config_file_list',
+ nargs='+',
+ type=str,
+        help='Name of a config file (eg: person_v1 for config/person_v1.json/yaml) or a full path if it is not in the config folder; you can pass multiple config files to run them sequentially'
+ )
+
+    # flag to continue with remaining jobs if one fails
+ parser.add_argument(
+ '-r', '--recover',
+ action='store_true',
+ help='Continue running additional jobs even if a job fails'
+ )
+
+    # optional name to replace the [name] tag in the config file
+ parser.add_argument(
+ '-n', '--name',
+ type=str,
+ default=None,
+ help='Name to replace [name] tag in config file, useful for shared config file'
+ )
+ args = parser.parse_args()
+
+ config_file_list = args.config_file_list
+ if len(config_file_list) == 0:
+ raise Exception("You must provide at least one config file")
+
+ jobs_completed = 0
+ jobs_failed = 0
+
+ print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")
+
+ for config_file in config_file_list:
+ try:
+ job = get_job(config_file, args.name)
+ job.run()
+ job.cleanup()
+ jobs_completed += 1
+ except Exception as e:
+ print(f"Error running job: {e}")
+ jobs_failed += 1
+ if not args.recover:
+ print_end_message(jobs_completed, jobs_failed)
+ raise e
+
+
+if __name__ == '__main__':
+ main()
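`run.py` forwards the optional `--name` to `get_job`, which replaces the `[name]` tag in a shared config. The real substitution lives in `toolkit.job`, which this diff does not show; the sketch below only assumes it amounts to a recursive string replace over the loaded YAML:

```python
import yaml

def replace_name_tag(node, name: str):
    """Recursively replace the [name] placeholder in a loaded config."""
    if isinstance(node, dict):
        return {k: replace_name_tag(v, name) for k, v in node.items()}
    if isinstance(node, list):
        return [replace_name_tag(v, name) for v in node]
    if isinstance(node, str):
        return node.replace("[name]", name)
    return node

raw = yaml.safe_load("""
config:
  name: "[name]"
  process:
    - training_folder: "output/[name]"
""")
print(replace_name_tag(raw, "my_first_flux_lora_v1"))
```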
diff --git a/run_modal.py b/run_modal.py
new file mode 100644
index 0000000000000000000000000000000000000000..4675c1cb8ec709126317dcba02315177df777f68
--- /dev/null
+++ b/run_modal.py
@@ -0,0 +1,175 @@
+'''
+
+ostris/ai-toolkit on https://modal.com
+Run training with the following command:
+modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml
+
+'''
+
+import os
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+import sys
+import modal
+from dotenv import load_dotenv
+# Load the .env file if it exists
+load_dotenv()
+
+sys.path.insert(0, "/root/ai-toolkit")
+# must come before ANY torch or fastai imports
+# import toolkit.cuda_malloc
+
+# turn off diffusers telemetry until I can figure out how to make it opt-in
+os.environ['DISABLE_TELEMETRY'] = 'YES'
+
+# define the volume for storing model outputs, using "creating volumes lazily": https://modal.com/docs/guide/volumes
+# you will find your model, samples and optimizer stored in: https://modal.com/storage/your-username/main/flux-lora-models
+model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True)
+
+# modal_output, due to "cannot mount volume on non-empty path" requirement
+MOUNT_DIR = "/root/ai-toolkit/modal_output" # modal_output, due to "cannot mount volume on non-empty path" requirement
+
+# define modal app
+image = (
+ modal.Image.debian_slim(python_version="3.11")
+ # install required system and pip packages, more about this modal approach: https://modal.com/docs/examples/dreambooth_app
+ .apt_install("libgl1", "libglib2.0-0")
+ .pip_install(
+ "python-dotenv",
+ "torch",
+ "diffusers[torch]",
+ "transformers",
+ "ftfy",
+ "torchvision",
+ "oyaml",
+ "opencv-python",
+ "albumentations",
+ "safetensors",
+ "lycoris-lora==1.8.3",
+ "flatten_json",
+ "pyyaml",
+ "tensorboard",
+ "kornia",
+ "invisible-watermark",
+ "einops",
+ "accelerate",
+ "toml",
+ "pydantic",
+ "omegaconf",
+ "k-diffusion",
+ "open_clip_torch",
+ "timm",
+ "prodigyopt",
+ "controlnet_aux==0.0.7",
+ "bitsandbytes",
+ "hf_transfer",
+ "lpips",
+ "pytorch_fid",
+ "optimum-quanto",
+ "sentencepiece",
+ "huggingface_hub",
+ "peft"
+ )
+)
+
+# mount for the entire ai-toolkit directory
+# example: "/Users/username/ai-toolkit" is the local directory, "/root/ai-toolkit" is the remote directory
+code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
+
+# create the Modal app with the necessary mounts and volumes
+app = modal.App(name="flux-lora-training", image=image, mounts=[code_mount], volumes={MOUNT_DIR: model_volume})
+
+# Check if we have DEBUG_TOOLKIT in env
+if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
+ # Set torch to trace mode
+ import torch
+ torch.autograd.set_detect_anomaly(True)
+
+import argparse
+from toolkit.job import get_job
+
+def print_end_message(jobs_completed, jobs_failed):
+ failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
+ completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
+
+ print("")
+ print("========================================")
+ print("Result:")
+ if len(completed_string) > 0:
+ print(f" - {completed_string}")
+ if len(failure_string) > 0:
+ print(f" - {failure_string}")
+ print("========================================")
+
+
+@app.function(
+ # request a GPU with at least 24GB VRAM
+ # more about modal GPU's: https://modal.com/docs/guide/gpu
+ gpu="A100", # gpu="H100"
+ # more about modal timeouts: https://modal.com/docs/guide/timeouts
+ timeout=7200 # 2 hours, increase or decrease if needed
+)
+def main(config_file_list_str: str, recover: bool = False, name: str = None):
+ # convert the config file list from a string to a list
+ config_file_list = config_file_list_str.split(",")
+
+ jobs_completed = 0
+ jobs_failed = 0
+
+ print(f"Running {len(config_file_list)} job{'' if len(config_file_list) == 1 else 's'}")
+
+ for config_file in config_file_list:
+ try:
+ job = get_job(config_file, name)
+
+ job.config['process'][0]['training_folder'] = MOUNT_DIR
+ os.makedirs(MOUNT_DIR, exist_ok=True)
+ print(f"Training outputs will be saved to: {MOUNT_DIR}")
+
+ # run the job
+ job.run()
+
+ # commit the volume after training
+ model_volume.commit()
+
+ job.cleanup()
+ jobs_completed += 1
+ except Exception as e:
+ print(f"Error running job: {e}")
+ jobs_failed += 1
+ if not recover:
+ print_end_message(jobs_completed, jobs_failed)
+ raise e
+
+ print_end_message(jobs_completed, jobs_failed)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ # require at least one config file
+ parser.add_argument(
+ 'config_file_list',
+ nargs='+',
+ type=str,
+        help='Name of a config file (eg: person_v1 for config/person_v1.json/yaml) or a full path if it is not in the config folder; you can pass multiple config files to run them sequentially'
+ )
+
+ # flag to continue if a job fails
+ parser.add_argument(
+ '-r', '--recover',
+ action='store_true',
+ help='Continue running additional jobs even if a job fails'
+ )
+
+ # optional name replacement for config file
+ parser.add_argument(
+ '-n', '--name',
+ type=str,
+ default=None,
+ help='Name to replace [name] tag in config file, useful for shared config file'
+ )
+ args = parser.parse_args()
+
+ # convert list of config files to a comma-separated string for Modal compatibility
+ config_file_list_str = ",".join(args.config_file_list)
+
+ main.call(config_file_list_str=config_file_list_str, recover=args.recover, name=args.name)
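Note that `run_modal.py` hard-codes the local checkout path in `Mount.from_local_dir("/Users/username/ai-toolkit", ...)`, so every user has to edit it before running. One possible tweak, assuming the script stays at the repository root as in this diff, is to derive the path from the file location:

```python
import os
import modal

# resolve the local repository root from this file instead of hard-coding it
LOCAL_REPO_DIR = os.path.dirname(os.path.abspath(__file__))

code_mount = modal.Mount.from_local_dir(LOCAL_REPO_DIR, remote_path="/root/ai-toolkit")
```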
diff --git a/run_modal_from_hf.py b/run_modal_from_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..f839e5c23f1520024eaa2b46b44bf9a47537dc53
--- /dev/null
+++ b/run_modal_from_hf.py
@@ -0,0 +1,231 @@
+'''
+ostris/ai-toolkit on https://modal.com
+This module provides the Modal app and main function for training FLUX LoRA models.
+The main() function is meant to be called from hf_ui.py, not run directly.
+'''
+
+import os
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+import sys
+import modal
+from dotenv import load_dotenv
+# Load the .env file if it exists
+load_dotenv()
+import yaml
+import traceback
+import zipfile
+
+sys.path.insert(0, "/root/ai-toolkit")
+# must come before ANY torch or fastai imports
+# import toolkit.cuda_malloc
+
+# turn off diffusers telemetry until I can figure out how to make it opt-in
+os.environ['DISABLE_TELEMETRY'] = 'YES'
+# Declare the Modal secrets
+hf_secret = modal.Secret.from_name("huggingface-secret")
+wandb_secret = modal.Secret.from_name("wandb-secret")
+
+# define the volume for storing model outputs, using "creating volumes lazily": https://modal.com/docs/guide/volumes
+# you will find your model, samples and optimizer stored in: https://modal.com/storage/your-username/main/flux-lora-models
+model_volume = modal.Volume.from_name("flux-lora-models", create_if_missing=True)
+
+# modal_output, due to "cannot mount volume on non-empty path" requirement
+MOUNT_DIR = "/root/ai-toolkit/modal_output" # modal_output, due to "cannot mount volume on non-empty path" requirement
+
+# define modal app
+image = (
+ modal.Image.debian_slim(python_version="3.11")
+ # install required system and pip packages, more about this modal approach: https://modal.com/docs/examples/dreambooth_app
+ .apt_install("libgl1", "libglib2.0-0")
+ .pip_install(
+ "python-dotenv",
+ "torch",
+ "diffusers[torch]",
+ "transformers",
+ "ftfy",
+ "torchvision",
+ "oyaml",
+ "opencv-python",
+ "albumentations",
+ "safetensors",
+ "lycoris-lora==1.8.3",
+ "flatten_json",
+ "pyyaml",
+ "tensorboard",
+ "kornia",
+ "invisible-watermark",
+ "einops",
+ "accelerate",
+ "toml",
+ "pydantic",
+ "omegaconf",
+ "k-diffusion",
+ "open_clip_torch",
+ "timm",
+ "prodigyopt",
+ "controlnet_aux==0.0.7",
+ "bitsandbytes",
+ "hf_transfer",
+ "lpips",
+ "pytorch_fid",
+ "optimum-quanto",
+ "sentencepiece",
+ "huggingface_hub",
+ "peft",
+ "wandb",
+ )
+)
+
+# Mount the code from the root directory of the HF Space
+code_mount = modal.Mount.from_local_dir(
+    local_dir="/home/user/app",  # default path inside an HF Space
+ remote_path="/root/ai-toolkit"
+)
+
+# create the Modal app with the necessary mounts and volumes
+app = modal.App(name="flux-lora-training", image=image, mounts=[code_mount], volumes={MOUNT_DIR: model_volume})
+
+# Check if we have DEBUG_TOOLKIT in env
+if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
+ # Set torch to trace mode
+ import torch
+ torch.autograd.set_detect_anomaly(True)
+
+import argparse
+from toolkit.job import get_job
+from toolkit.logging import WandbLogger
+
+def print_end_message(jobs_completed, jobs_failed):
+ failure_string = f"{jobs_failed} failure{'' if jobs_failed == 1 else 's'}" if jobs_failed > 0 else ""
+ completed_string = f"{jobs_completed} completed job{'' if jobs_completed == 1 else 's'}"
+
+ print("")
+ print("========================================")
+ print("Result:")
+ if len(completed_string) > 0:
+ print(f" - {completed_string}")
+ if len(failure_string) > 0:
+ print(f" - {failure_string}")
+ print("========================================")
+
+
+@app.function(
+ # request a GPU with at least 24GB VRAM
+ # more about modal GPU's: https://modal.com/docs/guide/gpu
+ gpu="A100", # gpu="H100"
+ # more about modal timeouts: https://modal.com/docs/guide/timeouts
+ timeout=7200, # 2 hours, increase or decrease if needed
+ secrets=[hf_secret, wandb_secret]
+)
+def main(config_file_list_str: str, recover: bool = False, name: str = None):
+    # The secrets are automatically injected as environment variables:
+    # os.environ["HF_TOKEN"] and os.environ["WANDB_API_KEY"]
+
+ # convert the config file list from a string to a list
+ # config_file_list = config_file_list_str.split(",")
+ # convert the config string into a usable dict
+ config = None
+ try:
+ config = yaml.safe_load(config_file_list_str)
+ except Exception as e:
+ print(f"Error loading config file: {e}")
+ traceback.print_exc()
+ raise e
+
+ jobs_completed = 0
+ jobs_failed = 0
+
+ print(f"Running {config['config']['name']}")
+
+ try:
+ # 1. validate config file to make sure required keys are present
+ if 'config' not in config:
+ raise ValueError("config file must have a `config` section")
+ if 'process' not in config['config']:
+ raise ValueError("config file must have a `process` section")
+ if len(config['config']['process']) == 0:
+ raise ValueError("config file must have at least one process")
+ if 'type' not in config['config']['process'][0]:
+ raise ValueError("config file process must have a `type`")
+ if 'training_folder' not in config['config']['process'][0]:
+ raise ValueError("config file process must have a `training_folder`")
+ if not config['config']['process'][0]['training_folder'].startswith("/root/ai-toolkit"):
+ raise ValueError("config file process training_folder path must start with /root/ai-toolkit")
+
+ # find a dataset inside process object
+ datasets = config['config']['process'][0].get('datasets', None)
+ if datasets is not None and isinstance(datasets, list):
+ for dataset in datasets:
+ if 'folder_path' in dataset:
+ if not dataset['folder_path'].startswith('/root/ai-toolkit'):
+ raise ValueError("config file process dataset folder_path must start with /root/ai-toolkit")
+
+ job = get_job(config, name)
+
+ job.config['process'][0]['training_folder'] = MOUNT_DIR
+ os.makedirs(MOUNT_DIR, exist_ok=True)
+ print(f"Training outputs will be saved to: {MOUNT_DIR}")
+
+ # setup wandb
+ if config['config']['process'][0]['logging']['use_wandb']:
+ wandb_token = os.environ.get('WANDB_API_KEY', None)
+ if wandb_token:
+ wandb_logger = WandbLogger(
+ project="flux-lora-training",
+ run_name=name,
+ config=job.raw_config,
+ )
+ job.meta["wandb"] = wandb_logger.run.id
+ job.process[0].logger = wandb_logger
+ else:
+ print("WandB token not found, skipping WandB logging")
+ config['config']['process'][0]['logging']['use_wandb'] = False # disable if no key was given
+
+ # handle dataset zip
+ datasets = config['config']['process'][0].get('datasets', None)
+ if datasets is not None and isinstance(datasets, list):
+ for dataset in datasets:
+ dataset_path = dataset.get('folder_path', None)
+ if dataset_path is not None:
+                    # Check whether the folder contains a zip file
+ for file in os.listdir(dataset_path):
+ if file.lower().endswith('.zip'):
+ zip_path = os.path.join(dataset_path, file)
+                            # Create a subfolder to extract into
+ extract_path = os.path.join(dataset_path, 'extracted')
+ os.makedirs(extract_path, exist_ok=True)
+
+ print(f"Extracting dataset zip file: {zip_path}")
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+ zip_ref.extractall(extract_path)
+
+                            # Update the dataset path in the config
+ dataset['folder_path'] = extract_path
+                            # Remove the zip file after extraction
+ os.remove(zip_path)
+ print(f"Dataset extracted to: {extract_path}")
+ break
+
+ # run the job
+ job.run()
+
+ if config['config']['process'][0]['logging']['use_wandb']:
+ wandb_logger.finish()
+
+ # commit the volume after training
+ model_volume.commit()
+
+ job.cleanup()
+ jobs_completed += 1
+
+ except Exception as e:
+ print(f"Error running job: {e}")
+ if 'response' in e.__dict__:
+ print(f" - Response code: {e.response.status_code} text: {e.response.text}")
+ jobs_failed += 1
+ traceback.print_exc()
+ if not recover:
+ print_end_message(jobs_completed, jobs_failed)
+ raise e
+
+ print_end_message(jobs_completed, jobs_failed)
\ No newline at end of file
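The validation block in `main()` above rejects configs missing a `config` section, a `process` list with a `type`, a `training_folder` under `/root/ai-toolkit`, or dataset `folder_path`s outside that prefix, and it also reads `logging.use_wandb` unconditionally. For illustration only, here is a minimal config shaped to pass those checks, written as a Python dict even though the function itself receives it as a YAML string:

```python
minimal_config = {
    "config": {
        "name": "my-lora",
        "process": [
            {
                "type": "sd_trainer",
                # both paths must live under /root/ai-toolkit to pass validation
                "training_folder": "/root/ai-toolkit/modal_output",
                "datasets": [
                    {"folder_path": "/root/ai-toolkit/datasets/my-lora"},
                ],
                # read unconditionally by main(), so it must be present
                "logging": {"use_wandb": False},
            }
        ],
    }
}
```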
diff --git a/scripts/convert_cog.py b/scripts/convert_cog.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba4f6e73c1d3e444583319b37557ad36ec988ccf
--- /dev/null
+++ b/scripts/convert_cog.py
@@ -0,0 +1,128 @@
+import json
+from collections import OrderedDict
+import os
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+device = torch.device('cpu')
+
+# [diffusers] -> kohya
+embedding_mapping = {
+ 'text_encoders_0': 'clip_l',
+ 'text_encoders_1': 'clip_g'
+}
+
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+KEYMAP_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps')
+sdxl_keymap_path = os.path.join(KEYMAP_ROOT, 'stable_diffusion_locon_sdxl.json')
+
+# load keymap
+with open(sdxl_keymap_path, 'r') as f:
+ ldm_diffusers_keymap = json.load(f)['ldm_diffusers_keymap']
+
+# invert the item / key pairs
+diffusers_ldm_keymap = {v: k for k, v in ldm_diffusers_keymap.items()}
+
+
+def get_ldm_key(diffuser_key):
+ diffuser_key = f"lora_unet_{diffuser_key.replace('.', '_')}"
+ diffuser_key = diffuser_key.replace('_lora_down_weight', '.lora_down.weight')
+ diffuser_key = diffuser_key.replace('_lora_up_weight', '.lora_up.weight')
+ diffuser_key = diffuser_key.replace('_alpha', '.alpha')
+ diffuser_key = diffuser_key.replace('_processor_to_', '_to_')
+ diffuser_key = diffuser_key.replace('_to_out.', '_to_out_0.')
+ if diffuser_key in diffusers_ldm_keymap:
+ return diffusers_ldm_keymap[diffuser_key]
+ else:
+ raise KeyError(f"Key {diffuser_key} not found in keymap")
+
+
+def convert_cog(lora_path, embedding_path):
+ embedding_state_dict = OrderedDict()
+ lora_state_dict = OrderedDict()
+
+ # # normal dict
+ # normal_dict = OrderedDict()
+ # example_path = "/mnt/Models/stable-diffusion/models/LoRA/sdxl/LogoRedmond_LogoRedAF.safetensors"
+ # with safe_open(example_path, framework="pt", device='cpu') as f:
+ # keys = list(f.keys())
+ # for key in keys:
+ # normal_dict[key] = f.get_tensor(key)
+
+ with safe_open(embedding_path, framework="pt", device='cpu') as f:
+ keys = list(f.keys())
+ for key in keys:
+ new_key = embedding_mapping[key]
+ embedding_state_dict[new_key] = f.get_tensor(key)
+
+ with safe_open(lora_path, framework="pt", device='cpu') as f:
+ keys = list(f.keys())
+ lora_rank = None
+
+        # get the lora rank first; check the first few linear layers just to be safe
+        num_checked = 0  # counter must live outside the loop or the early break never triggers
+        for key in keys:
+            new_key = get_ldm_key(key)
+            tensor = f.get_tensor(key)
+            if len(tensor.shape) == 2:
+                this_dim = min(tensor.shape)
+                if lora_rank is None:
+                    lora_rank = this_dim
+                elif lora_rank != this_dim:
+                    raise ValueError(f"lora rank is not consistent, got {tensor.shape}")
+                else:
+                    num_checked += 1
+            if num_checked >= 3:
+                break
+
+ for key in keys:
+ new_key = get_ldm_key(key)
+ tensor = f.get_tensor(key)
+ if new_key.endswith('.lora_down.weight'):
+ alpha_key = new_key.replace('.lora_down.weight', '.alpha')
+                # diffusers does not store an alpha; it effectively uses a multiplier of 1 (alpha == rank),
+                # so we write an alpha tensor equal to the lora rank determined above
+ lora_state_dict[alpha_key] = torch.ones(1).to(tensor.device, tensor.dtype) * lora_rank
+
+ lora_state_dict[new_key] = tensor
+
+ return lora_state_dict, embedding_state_dict
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'lora_path',
+ type=str,
+ help='Path to lora file'
+ )
+ parser.add_argument(
+ 'embedding_path',
+ type=str,
+ help='Path to embedding file'
+ )
+
+ parser.add_argument(
+ '--lora_output',
+ type=str,
+ default="lora_output",
+ )
+
+ parser.add_argument(
+ '--embedding_output',
+ type=str,
+ default="embedding_output",
+ )
+
+ args = parser.parse_args()
+
+ lora_state_dict, embedding_state_dict = convert_cog(args.lora_path, args.embedding_path)
+
+ # save them
+ save_file(lora_state_dict, args.lora_output)
+ save_file(embedding_state_dict, args.embedding_output)
+ print(f"Saved lora to {args.lora_output}")
+ print(f"Saved embedding to {args.embedding_output}")
diff --git a/scripts/convert_lora_to_peft_format.py b/scripts/convert_lora_to_peft_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..3034db646ce0cbf784940df17a45e2468063f485
--- /dev/null
+++ b/scripts/convert_lora_to_peft_format.py
@@ -0,0 +1,91 @@
+# currently only works with flux as support is not quite there yet
+
+import argparse
+import os.path
+from collections import OrderedDict
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ 'input_path',
+ type=str,
+ help='Path to original sdxl model'
+)
+parser.add_argument(
+ 'output_path',
+ type=str,
+ help='output path'
+)
+args = parser.parse_args()
+args.input_path = os.path.abspath(args.input_path)
+args.output_path = os.path.abspath(args.output_path)
+
+from safetensors.torch import load_file, save_file
+
+meta = OrderedDict()
+meta['format'] = 'pt'
+
+state_dict = load_file(args.input_path)
+
+# peft doesn't have an alpha, so we need to scale the weights
+alpha_keys = [
+ 'lora_transformer_single_transformer_blocks_0_attn_to_q.alpha' # flux
+]
+
+# keys where the rank is in the first dimension
+rank_idx0_keys = [
+ 'lora_transformer_single_transformer_blocks_0_attn_to_q.lora_down.weight'
+ # 'transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight'
+]
+
+alpha = None
+rank = None
+
+for key in rank_idx0_keys:
+ if key in state_dict:
+ rank = int(state_dict[key].shape[0])
+ break
+
+if rank is None:
+ raise ValueError(f'Could not find rank in state dict')
+
+for key in alpha_keys:
+ if key in state_dict:
+ alpha = int(state_dict[key])
+ break
+
+if alpha is None:
+ # set to rank if not found
+ alpha = rank
+
+
+up_multiplier = alpha / rank
+
+new_state_dict = {}
+
+for key, value in state_dict.items():
+ if key.endswith('.alpha'):
+ continue
+
+ orig_dtype = value.dtype
+
+ new_val = value.float() * up_multiplier
+
+ new_key = key
+ new_key = new_key.replace('lora_transformer_', 'transformer.')
+ for i in range(100):
+ new_key = new_key.replace(f'transformer_blocks_{i}_', f'transformer_blocks.{i}.')
+ new_key = new_key.replace('lora_down', 'lora_A')
+ new_key = new_key.replace('lora_up', 'lora_B')
+ new_key = new_key.replace('_lora', '.lora')
+ new_key = new_key.replace('attn_', 'attn.')
+ new_key = new_key.replace('ff_', 'ff.')
+ new_key = new_key.replace('context_net_', 'context.net.')
+ new_key = new_key.replace('0_proj', '0.proj')
+ new_key = new_key.replace('norm_linear', 'norm.linear')
+ new_key = new_key.replace('norm_out_linear', 'norm_out.linear')
+ new_key = new_key.replace('to_out_', 'to_out.')
+
+ new_state_dict[new_key] = new_val.to(orig_dtype)
+
+save_file(new_state_dict, args.output_path, meta)
+print(f'Saved to {args.output_path}')
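The script above folds the kohya-style alpha into the weights because the PEFT layout stores no alpha: the multiplier it applies is `alpha / rank`, which is exactly 1.0 for the `linear=16, linear_alpha=16` defaults used in the notebooks, so in that common case the tensors pass through unchanged. A tiny numeric check of that multiplier, with made-up shapes:

```python
import torch

rank, alpha = 16, 16          # defaults used in the notebooks above
up_multiplier = alpha / rank  # 1.0 -> weights pass through unchanged

lora_down = torch.randn(rank, 3072)
assert torch.equal(lora_down * up_multiplier, lora_down)

# with a different alpha the weights really are rescaled
up_multiplier = 8 / 16        # e.g. linear=16, linear_alpha=8
print((lora_down * up_multiplier).abs().mean() / lora_down.abs().mean())  # prints tensor(0.5000)
```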
diff --git a/scripts/generate_sampler_step_scales.py b/scripts/generate_sampler_step_scales.py
new file mode 100644
index 0000000000000000000000000000000000000000..11efb3183becb48ec4a485565d53049fb6a8d11c
--- /dev/null
+++ b/scripts/generate_sampler_step_scales.py
@@ -0,0 +1,20 @@
+import argparse
+import torch
+import os
+from diffusers import StableDiffusionPipeline
+import sys
+
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+# add project root to path
+sys.path.append(PROJECT_ROOT)
+
+SAMPLER_SCALES_ROOT = os.path.join(PROJECT_ROOT, 'toolkit', 'samplers_scales')
+
+
+parser = argparse.ArgumentParser(description='Process some images.')
+add_arg = parser.add_argument
+add_arg('--model', type=str, required=True, help='Path to model')
+add_arg('--sampler', type=str, required=True, help='Name of sampler')
+
+args = parser.parse_args()
+
diff --git a/scripts/make_diffusers_model.py b/scripts/make_diffusers_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4536a9215540dd01321ef1426665db9d6ef6347f
--- /dev/null
+++ b/scripts/make_diffusers_model.py
@@ -0,0 +1,61 @@
+import argparse
+from collections import OrderedDict
+import sys
+import os
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_DIR)
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ 'input_path',
+ type=str,
+ help='Path to original sdxl model'
+)
+parser.add_argument(
+ 'output_path',
+ type=str,
+ help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+
+diffusers_model_config = ModelConfig(
+ name_or_path=args.input_path,
+ is_xl=args.sdxl,
+ is_v2=args.sd2,
+ is_ssd=args.ssd,
+ dtype=dtype,
+ )
+diffusers_sd = StableDiffusion(
+ model_config=diffusers_model_config,
+ device=device,
+ dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
diff --git a/scripts/make_lcm_sdxl_model.py b/scripts/make_lcm_sdxl_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e95ce795a39fe2837b80fcbf1950c256ad4c59
--- /dev/null
+++ b/scripts/make_lcm_sdxl_model.py
@@ -0,0 +1,67 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ 'input_path',
+ type=str,
+ help='Path to original sdxl model'
+)
+parser.add_argument(
+ 'output_path',
+ type=str,
+ help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+if args.sdxl:
+ adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.refiner:
+ adapter_id = "latent-consistency/lcm-lora-sdxl"
+elif args.ssd:
+ adapter_id = "latent-consistency/lcm-lora-ssd-1b"
+else:
+ adapter_id = "latent-consistency/lcm-lora-sdv1-5"
+
+
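+# load the base model, fuse the selected LCM-LoRA adapter into its weights, and save the result to output_path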
+diffusers_model_config = ModelConfig(
+ name_or_path=args.input_path,
+ is_xl=args.sdxl,
+ is_v2=args.sd2,
+ is_ssd=args.ssd,
+ dtype=dtype,
+ )
+diffusers_sd = StableDiffusion(
+ model_config=diffusers_model_config,
+ device=device,
+ dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.load_lora_weights(adapter_id)
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
diff --git a/scripts/patch_te_adapter.py b/scripts/patch_te_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7249a46d8e566c3889538c465359e6c66b1c9602
--- /dev/null
+++ b/scripts/patch_te_adapter.py
@@ -0,0 +1,42 @@
+import torch
+from safetensors.torch import save_file, load_file
+from collections import OrderedDict
+meta = OrderedDict()
+meta["format"] ="pt"
+
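+# merges trained text-encoder adapter to_k/to_v weights into a UNet state dict's cross-attention (attn2) projections; paths below are machine-specific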
+attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors")
+state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors")
+
+attn_list = []
+for key, value in state_dict.items():
+ if "attn1" in key:
+ attn_list.append(key)
+
+attn_names = ['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor']
+
+adapter_names = []
+for i in range(100):
+ if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict:
+ adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter")
+
+
+for i in range(len(adapter_names)):
+ adapter_name = adapter_names[i]
+ attn_name = attn_names[i]
+ adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight'
+ adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight'
+ state_k_name = attn_name.replace(".processor", ".to_k.weight")
+ state_v_name = attn_name.replace(".processor", ".to_v.weight")
+ if adapter_k_name in attn_dict:
+ state_dict[state_k_name] = attn_dict[adapter_k_name]
+ state_dict[state_v_name] = attn_dict[adapter_v_name]
+ else:
+ print("adapter_k_name", adapter_k_name)
+ print("state_k_name", state_k_name)
+
+for key, value in state_dict.items():
+ state_dict[key] = value.cpu().to(torch.float16)
+
+save_file(state_dict, "/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta)
+
+print("Done")
diff --git a/scripts/repair_dataset_folder.py b/scripts/repair_dataset_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad9d277508c19046b5737620a01b9eba09635e98
--- /dev/null
+++ b/scripts/repair_dataset_folder.py
@@ -0,0 +1,65 @@
+import argparse
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from tqdm import tqdm
+import os
+
+parser = argparse.ArgumentParser(description='Process some images.')
+parser.add_argument("input_folder", type=str, help="Path to folder containing images")
+
+args = parser.parse_args()
+
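+# converts every non-JPEG image in the folder to JPEG (applying EXIF rotation) and deletes images that fail to open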
+img_types = ['.jpg', '.jpeg', '.png', '.webp']
+
+# find all images in the input folder
+images = []
+for root, _, files in os.walk(args.input_folder):
+ for file in files:
+ if file.lower().endswith(tuple(img_types)):
+ images.append(os.path.join(root, file))
+print(f"Found {len(images)} images")
+
+num_skipped = 0
+num_repaired = 0
+num_deleted = 0
+
+pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image")
+for img_path in images:
+ filename = os.path.basename(img_path)
+ filename_no_ext, file_extension = os.path.splitext(filename)
+ # if it is jpg, ignore
+ if file_extension.lower() == '.jpg':
+ num_skipped += 1
+ pbar.update(1)
+
+ continue
+
+ try:
+ img = Image.open(img_path)
+ except Exception as e:
+ print(f"Error opening {img_path}: {e}")
+ # delete it
+ os.remove(img_path)
+ num_deleted += 1
+ pbar.update(1)
+ pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
+ continue
+
+
+ try:
+ img = exif_transpose(img)
+ except Exception as e:
+ print(f"Error rotating {img_path}: {e}")
+
+ new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg')
+
+ img = img.convert("RGB")
+ img.save(new_path, quality=95)
+ # remove the old file
+ os.remove(img_path)
+ num_repaired += 1
+ pbar.update(1)
+ # update pbar
+ pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
+
+print("Done")
\ No newline at end of file
diff --git a/testing/compare_keys.py b/testing/compare_keys.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf4f95203fe1024daeb66ddd79696875f04578c7
--- /dev/null
+++ b/testing/compare_keys.py
@@ -0,0 +1,99 @@
+import argparse
+import os
+
+import torch
+from diffusers.loaders import LoraLoaderMixin
+from safetensors.torch import load_file
+from collections import OrderedDict
+import json
+# this was just used to match the vae keys to the diffusers keys
+# you probably won't need this. Unless they change them.... again... again
+# on second thought, you probably will
+
+device = torch.device('cpu')
+dtype = torch.float32
+
+parser = argparse.ArgumentParser()
+
+# require at least one config file
+parser.add_argument(
+ 'file_1',
+ nargs='+',
+ type=str,
+ help='Path to first safe tensor file'
+)
+
+parser.add_argument(
+ 'file_2',
+ nargs='+',
+ type=str,
+ help='Path to second safe tensor file'
+)
+
+args = parser.parse_args()
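+# example invocation (file names are illustrative): python testing/compare_keys.py model_a.safetensors model_b.safetensors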
+
+find_matches = False
+
+state_dict_file_1 = load_file(args.file_1[0])
+state_dict_1_keys = list(state_dict_file_1.keys())
+
+state_dict_file_2 = load_file(args.file_2[0])
+state_dict_2_keys = list(state_dict_file_2.keys())
+keys_in_both = []
+
+keys_not_in_state_dict_2 = []
+for key in state_dict_1_keys:
+ if key not in state_dict_2_keys:
+ keys_not_in_state_dict_2.append(key)
+
+keys_not_in_state_dict_1 = []
+for key in state_dict_2_keys:
+ if key not in state_dict_1_keys:
+ keys_not_in_state_dict_1.append(key)
+
+keys_in_both = []
+for key in state_dict_1_keys:
+ if key in state_dict_2_keys:
+ keys_in_both.append(key)
+
+# sort them
+keys_not_in_state_dict_2.sort()
+keys_not_in_state_dict_1.sort()
+keys_in_both.sort()
+
+
+json_data = {
+ "both": keys_in_both,
+ "not_in_state_dict_2": keys_not_in_state_dict_2,
+ "not_in_state_dict_1": keys_not_in_state_dict_1
+}
+json_data = json.dumps(json_data, indent=4)
+
+remaining_diffusers_values = OrderedDict()
+for key in keys_not_in_state_dict_1:
+ remaining_diffusers_values[key] = state_dict_file_2[key]
+
+# print(remaining_diffusers_values.keys())
+
+remaining_ldm_values = OrderedDict()
+for key in keys_not_in_state_dict_2:
+ remaining_ldm_values[key] = state_dict_file_1[key]
+
+# print(json_data)
+
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+json_save_path = os.path.join(project_root, 'config', 'keys.json')
+json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
+json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
+state_dict_1_filename = os.path.basename(args.file_1[0])
+state_dict_2_filename = os.path.basename(args.file_2[0])
+# save key names for each in own file
+with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
+ f.write(json.dumps(state_dict_1_keys, indent=4))
+
+with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f:
+ f.write(json.dumps(state_dict_2_keys, indent=4))
+
+
+with open(json_save_path, 'w') as f:
+ f.write(json_data)
\ No newline at end of file
diff --git a/testing/generate_lora_mapping.py b/testing/generate_lora_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..e632d2a662f6a6498a8d340074ea9c9a27ac431a
--- /dev/null
+++ b/testing/generate_lora_mapping.py
@@ -0,0 +1,130 @@
+from collections import OrderedDict
+
+import torch
+from safetensors.torch import load_file
+import argparse
+import os
+import json
+
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+keymap_path = os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', 'stable_diffusion_sdxl.json')
+
+# load keymap
+with open(keymap_path, 'r') as f:
+ keymap = json.load(f)
+
+lora_keymap = OrderedDict()
+
+# convert keymap to lora key naming
+for ldm_key, diffusers_key in keymap['ldm_diffusers_keymap'].items():
+ if ldm_key.endswith('.bias') or diffusers_key.endswith('.bias'):
+ # skip it
+ continue
+ # sdxl uses the same text-encoder (te) keys for locon in both kohya's format and ours
+ if ldm_key.startswith('conditioner'):
+ #skip it
+ continue
+ # ignore vae
+ if ldm_key.startswith('first_stage_model'):
+ continue
+ ldm_key = ldm_key.replace('model.diffusion_model.', 'lora_unet_')
+ ldm_key = ldm_key.replace('.weight', '')
+ ldm_key = ldm_key.replace('.', '_')
+
+ diffusers_key = diffusers_key.replace('unet_', 'lora_unet_')
+ diffusers_key = diffusers_key.replace('.weight', '')
+ diffusers_key = diffusers_key.replace('.', '_')
+
+ lora_keymap[f"{ldm_key}.alpha"] = f"{diffusers_key}.alpha"
+ lora_keymap[f"{ldm_key}.lora_down.weight"] = f"{diffusers_key}.lora_down.weight"
+ lora_keymap[f"{ldm_key}.lora_up.weight"] = f"{diffusers_key}.lora_up.weight"
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("input", help="input file")
+parser.add_argument("input2", help="input2 file")
+
+args = parser.parse_args()
+
+# name = args.name
+# if args.sdxl:
+# name += '_sdxl'
+# elif args.sd2:
+# name += '_sd2'
+# else:
+# name += '_sd1'
+name = 'stable_diffusion_locon_sdxl'
+
+locon_save = load_file(args.input)
+our_save = load_file(args.input2)
+
+our_extra_keys = list(set(our_save.keys()) - set(locon_save.keys()))
+locon_extra_keys = list(set(locon_save.keys()) - set(our_save.keys()))
+
+print(f"we have {len(our_extra_keys)} extra keys")
+print(f"locon has {len(locon_extra_keys)} extra keys")
+
+save_dtype = torch.float16
+print(f"our extra keys: {our_extra_keys}")
+print(f"locon extra keys: {locon_extra_keys}")
+
+
+def export_state_dict(our_save):
+ converted_state_dict = OrderedDict()
+ for key, value in our_save.items():
+ # text encoders share keys for some reason
+ if key.startswith('lora_te'):
+ converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+ else:
+ converted_key = key
+ for ldm_key, diffusers_key in lora_keymap.items():
+ if converted_key == diffusers_key:
+ converted_key = ldm_key
+
+ converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
+ return converted_state_dict
+
+def import_state_dict(loaded_state_dict):
+ converted_state_dict = OrderedDict()
+ for key, value in loaded_state_dict.items():
+ if key.startswith('lora_te'):
+ converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+ else:
+ converted_key = key
+ for ldm_key, diffusers_key in lora_keymap.items():
+ if converted_key == ldm_key:
+ converted_key = diffusers_key
+
+ converted_state_dict[converted_key] = value.detach().to('cpu', dtype=save_dtype)
+ return converted_state_dict
+
+
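+# round-trip check: converting our keys to the locon naming and back should reproduce the original key set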
+# check it again
+converted_state_dict = export_state_dict(our_save)
+converted_extra_keys = list(set(converted_state_dict.keys()) - set(locon_save.keys()))
+locon_extra_keys = list(set(locon_save.keys()) - set(converted_state_dict.keys()))
+
+
+print(f"we have {len(converted_extra_keys)} extra keys")
+print(f"locon has {len(locon_extra_keys)} extra keys")
+
+print(f"our extra keys: {converted_extra_keys}")
+
+# convert back
+cycle_state_dict = import_state_dict(converted_state_dict)
+cycle_extra_keys = list(set(cycle_state_dict.keys()) - set(our_save.keys()))
+our_extra_keys = list(set(our_save.keys()) - set(cycle_state_dict.keys()))
+
+print(f"we have {len(our_extra_keys)} extra keys")
+print(f"cycle has {len(cycle_extra_keys)} extra keys")
+
+# save keymap
+to_save = OrderedDict()
+to_save['ldm_diffusers_keymap'] = lora_keymap
+
+with open(os.path.join(PROJECT_ROOT, 'toolkit', 'keymaps', f'{name}.json'), 'w') as f:
+ json.dump(to_save, f, indent=4)
+
+
+
diff --git a/testing/generate_weight_mappings.py b/testing/generate_weight_mappings.py
new file mode 100644
index 0000000000000000000000000000000000000000..346fe09d5c98a22a3c06ac9ae1dadb549a196193
--- /dev/null
+++ b/testing/generate_weight_mappings.py
@@ -0,0 +1,479 @@
+import argparse
+import gc
+import os
+import re
+# add project root to sys path
+import sys
+
+from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+from diffusers.loaders import LoraLoaderMixin
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+import json
+from tqdm import tqdm
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+KEYMAPS_FOLDER = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'toolkit', 'keymaps')
+
+device = torch.device('cpu')
+dtype = torch.float32
+
+
+def flush():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+def get_reduced_shape(shape_tuple):
+ # iterate through the shape and drop dimensions of size 1
+ new_shape = []
+ for dim in shape_tuple:
+ if dim != 1:
+ new_shape.append(dim)
+ return tuple(new_shape)
+
+
+parser = argparse.ArgumentParser()
+
+# require at least one config file
+parser.add_argument(
+ 'file_1',
+ nargs='+',
+ type=str,
+ help='Path to first safe tensor file'
+)
+
+parser.add_argument('--name', type=str, default='stable_diffusion', help='name for mapping to make')
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--vega', action='store_true', help='is vega model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+
+file_path = args.file_1[0]
+
+find_matches = False
+
+print(f'Loading diffusers model')
+
+ignore_ldm_begins_with = []
+
+diffusers_file_path = file_path if len(args.file_1) == 1 else args.file_1[1]
+if args.ssd:
+ diffusers_file_path = "segmind/SSD-1B"
+if args.vega:
+ diffusers_file_path = "segmind/Segmind-Vega"
+
+# if args.refiner:
+# diffusers_file_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+if not args.refiner:
+
+ diffusers_model_config = ModelConfig(
+ name_or_path=diffusers_file_path,
+ is_xl=args.sdxl,
+ is_v2=args.sd2,
+ is_ssd=args.ssd,
+ is_vega=args.vega,
+ dtype=dtype,
+ )
+ diffusers_sd = StableDiffusion(
+ model_config=diffusers_model_config,
+ device=device,
+ dtype=dtype,
+ )
+ diffusers_sd.load_model()
+ # delete things we don't need
+ del diffusers_sd.tokenizer
+ flush()
+
+ print(f'Loading ldm model')
+ diffusers_state_dict = diffusers_sd.state_dict()
+else:
+ # the refiner won't work directly with StableDiffusion
+ # so we need to load the model and then load the state dict
+ diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
+ diffusers_file_path,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+ ).to(device)
+ # diffusers_pipeline = StableDiffusionXLPipeline.from_single_file(
+ # file_path,
+ # torch_dtype=torch.float16,
+ # use_safetensors=True,
+ # variant="fp16",
+ # ).to(device)
+
+ SD_PREFIX_VAE = "vae"
+ SD_PREFIX_UNET = "unet"
+ SD_PREFIX_REFINER_UNET = "refiner_unet"
+ SD_PREFIX_TEXT_ENCODER = "te"
+
+ SD_PREFIX_TEXT_ENCODER1 = "te0"
+ SD_PREFIX_TEXT_ENCODER2 = "te1"
+
+ diffusers_state_dict = OrderedDict()
+ for k, v in diffusers_pipeline.vae.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
+ diffusers_state_dict[new_key] = v
+ for k, v in diffusers_pipeline.text_encoder_2.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
+ diffusers_state_dict[new_key] = v
+ for k, v in diffusers_pipeline.unet.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
+ diffusers_state_dict[new_key] = v
+
+ # add ignore ones as we are only going to focus on unet and copy the rest
+ # ignore_ldm_begins_with = ["conditioner.", "first_stage_model."]
+
+diffusers_dict_keys = list(diffusers_state_dict.keys())
+
+ldm_state_dict = load_file(file_path)
+ldm_dict_keys = list(ldm_state_dict.keys())
+
+ldm_diffusers_keymap = OrderedDict()
+ldm_diffusers_shape_map = OrderedDict()
+ldm_operator_map = OrderedDict()
+diffusers_operator_map = OrderedDict()
+
+total_keys = len(ldm_dict_keys)
+
+matched_ldm_keys = []
+matched_diffusers_keys = []
+
+error_margin = 1e-8
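+# keys are matched by comparing tensor values (MSE below error_margin), not by name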
+
+tmp_merge_key = "TMP___MERGE"
+
+te_suffix = ''
+proj_pattern_weight = None
+proj_pattern_bias = None
+text_proj_layer = None
+if args.sdxl or args.ssd or args.vega:
+ te_suffix = '1'
+ ldm_res_block_prefix = "conditioner.embedders.1.model.transformer.resblocks"
+ proj_pattern_weight = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
+ proj_pattern_bias = r"conditioner\.embedders\.1\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
+ text_proj_layer = "conditioner.embedders.1.model.text_projection"
+if args.refiner:
+ te_suffix = '1'
+ ldm_res_block_prefix = "conditioner.embedders.0.model.transformer.resblocks"
+ proj_pattern_weight = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
+ proj_pattern_bias = r"conditioner\.embedders\.0\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
+ text_proj_layer = "conditioner.embedders.0.model.text_projection"
+if args.sd2:
+ te_suffix = ''
+ ldm_res_block_prefix = "cond_stage_model.model.transformer.resblocks"
+ proj_pattern_weight = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_weight"
+ proj_pattern_bias = r"cond_stage_model\.model\.transformer\.resblocks\.(\d+)\.attn\.in_proj_bias"
+ text_proj_layer = "cond_stage_model.model.text_projection"
+
+if args.sdxl or args.sd2 or args.ssd or args.refiner or args.vega:
+ if "conditioner.embedders.1.model.text_projection" in ldm_dict_keys:
+ # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
+ d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection"].shape[0])
+ elif "conditioner.embedders.1.model.text_projection.weight" in ldm_dict_keys:
+ # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
+ d_model = int(ldm_state_dict["conditioner.embedders.1.model.text_projection.weight"].shape[0])
+ elif "conditioner.embedders.0.model.text_projection" in ldm_dict_keys:
+ # d_model = int(checkpoint[prefix + "text_projection"].shape[0]))
+ d_model = int(ldm_state_dict["conditioner.embedders.0.model.text_projection"].shape[0])
+ else:
+ d_model = 1024
+
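+ # d_model is the text-encoder hidden width; it is used to slice the fused in_proj qkv weights into separate q/k/v tensors below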
+ # do pre known merging
+ for ldm_key in ldm_dict_keys:
+ try:
+ match = re.match(proj_pattern_weight, ldm_key)
+ if match:
+ if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight":
+ print("here")
+ number = int(match.group(1))
+ new_val = torch.cat([
+ diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"],
+ diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"],
+ diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"],
+ ], dim=0)
+ # add to matched so we dont check them
+ matched_diffusers_keys.append(
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight")
+ matched_diffusers_keys.append(
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight")
+ matched_diffusers_keys.append(
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight")
+ # make diffusers convertable_dict
+ diffusers_state_dict[
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.weight"] = new_val
+
+ # add operator
+ ldm_operator_map[ldm_key] = {
+ "cat": [
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight",
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight",
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight",
+ ],
+ }
+
+ matched_ldm_keys.append(ldm_key)
+
+ # text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
+ # text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model: d_model * 2, :]
+ # text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2:, :]
+
+ # add diffusers operators
+ diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.weight"] = {
+ "slice": [
+ f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
+ f"0:{d_model}, :"
+ ]
+ }
+ diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.weight"] = {
+ "slice": [
+ f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
+ f"{d_model}:{d_model * 2}, :"
+ ]
+ }
+ diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.weight"] = {
+ "slice": [
+ f"{ldm_res_block_prefix}.{number}.attn.in_proj_weight",
+ f"{d_model * 2}:, :"
+ ]
+ }
+
+ match = re.match(proj_pattern_bias, ldm_key)
+ if match:
+ number = int(match.group(1))
+ new_val = torch.cat([
+ diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"],
+ diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"],
+ diffusers_state_dict[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"],
+ ], dim=0)
+ # add to matched so we dont check them
+ matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias")
+ matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias")
+ matched_diffusers_keys.append(f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias")
+ # make diffusers convertable_dict
+ diffusers_state_dict[
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.{tmp_merge_key}.bias"] = new_val
+
+ # add operator
+ ldm_operator_map[ldm_key] = {
+ "cat": [
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias",
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias",
+ f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias",
+ ],
+ }
+
+ matched_ldm_keys.append(ldm_key)
+
+ # add diffusers operators
+ diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.q_proj.bias"] = {
+ "slice": [
+ f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
+ f"0:{d_model}, :"
+ ]
+ }
+ diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.k_proj.bias"] = {
+ "slice": [
+ f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
+ f"{d_model}:{d_model * 2}, :"
+ ]
+ }
+ diffusers_operator_map[f"te{te_suffix}_text_model.encoder.layers.{number}.self_attn.v_proj.bias"] = {
+ "slice": [
+ f"{ldm_res_block_prefix}.{number}.attn.in_proj_bias",
+ f"{d_model * 2}:, :"
+ ]
+ }
+ except Exception as e:
+ print(f"Error on key {ldm_key}")
+ print(e)
+
+ # update keys
+ diffusers_dict_keys = list(diffusers_state_dict.keys())
+
+pbar = tqdm(ldm_dict_keys, desc='Matching ldm-diffusers keys', total=total_keys)
+# run through all weights and check mse between them to find matches
+for ldm_key in ldm_dict_keys:
+ ldm_shape_tuple = ldm_state_dict[ldm_key].shape
+ ldm_reduced_shape_tuple = get_reduced_shape(ldm_shape_tuple)
+ for diffusers_key in diffusers_dict_keys:
+ if ldm_key == "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight" and diffusers_key == "te1_text_model.encoder.layers.0.self_attn.q_proj.weight":
+ print("here")
+
+ diffusers_shape_tuple = diffusers_state_dict[diffusers_key].shape
+ diffusers_reduced_shape_tuple = get_reduced_shape(diffusers_shape_tuple)
+
+ # That was easy. Same key
+ # if ldm_key == diffusers_key:
+ # ldm_diffusers_keymap[ldm_key] = diffusers_key
+ # matched_ldm_keys.append(ldm_key)
+ # matched_diffusers_keys.append(diffusers_key)
+ # break
+
+ # if we already have this key mapped, skip it
+ if diffusers_key in matched_diffusers_keys:
+ continue
+
+ # if reduced shapes do not match skip it
+ if ldm_reduced_shape_tuple != diffusers_reduced_shape_tuple:
+ continue
+
+ ldm_weight = ldm_state_dict[ldm_key]
+ did_reduce_ldm = False
+ diffusers_weight = diffusers_state_dict[diffusers_key]
+ did_reduce_diffusers = False
+
+ # reduce the shapes to match if they are not the same
+ if ldm_shape_tuple != ldm_reduced_shape_tuple:
+ ldm_weight = ldm_weight.view(ldm_reduced_shape_tuple)
+ did_reduce_ldm = True
+
+ if diffusers_shape_tuple != diffusers_reduced_shape_tuple:
+ diffusers_weight = diffusers_weight.view(diffusers_reduced_shape_tuple)
+ did_reduce_diffusers = True
+
+ # check to see if they match within a margin of error
+ mse = torch.nn.functional.mse_loss(ldm_weight.float(), diffusers_weight.float())
+ if mse < error_margin:
+ ldm_diffusers_keymap[ldm_key] = diffusers_key
+ matched_ldm_keys.append(ldm_key)
+ matched_diffusers_keys.append(diffusers_key)
+
+ if did_reduce_ldm or did_reduce_diffusers:
+ ldm_diffusers_shape_map[ldm_key] = (ldm_shape_tuple, diffusers_shape_tuple)
+ if did_reduce_ldm:
+ del ldm_weight
+ if did_reduce_diffusers:
+ del diffusers_weight
+ flush()
+
+ break
+
+ pbar.update(1)
+
+pbar.close()
+
+name = args.name
+if args.sdxl:
+ name += '_sdxl'
+elif args.ssd:
+ name += '_ssd'
+elif args.vega:
+ name += '_vega'
+elif args.refiner:
+ name += '_refiner'
+elif args.sd2:
+ name += '_sd2'
+else:
+ name += '_sd1'
+
+# if len(matched_ldm_keys) != len(matched_diffusers_keys):
+unmatched_ldm_keys = [x for x in ldm_dict_keys if x not in matched_ldm_keys]
+unmatched_diffusers_keys = [x for x in diffusers_dict_keys if x not in matched_diffusers_keys]
+# has unmatched keys
+
+has_unmatched_keys = len(unmatched_ldm_keys) > 0 or len(unmatched_diffusers_keys) > 0
+
+
+def get_slices_from_string(s: str) -> tuple:
+ slice_strings = s.split(',')
+ slices = [eval(f"slice({component.strip()})") for component in slice_strings]
+ return tuple(slices)
+
+
+if has_unmatched_keys:
+
+ print(
+ f"Found {len(unmatched_ldm_keys)} unmatched ldm keys and {len(unmatched_diffusers_keys)} unmatched diffusers keys")
+
+ unmatched_obj = OrderedDict()
+ unmatched_obj['ldm'] = OrderedDict()
+ unmatched_obj['diffusers'] = OrderedDict()
+
+ print(f"Gathering info on unmatched keys")
+
+ for key in tqdm(unmatched_ldm_keys, desc='Unmatched LDM keys'):
+ # get min, max, mean, std
+ weight = ldm_state_dict[key]
+ weight_min = weight.min().item()
+ weight_max = weight.max().item()
+ unmatched_obj['ldm'][key] = {
+ 'shape': weight.shape,
+ "min": weight_min,
+ "max": weight_max,
+ }
+ del weight
+ flush()
+
+ for key in tqdm(unmatched_diffusers_keys, desc='Unmatched Diffusers keys'):
+ # get min, max, mean, std
+ weight = diffusers_state_dict[key]
+ weight_min = weight.min().item()
+ weight_max = weight.max().item()
+ unmatched_obj['diffusers'][key] = {
+ "shape": weight.shape,
+ "min": weight_min,
+ "max": weight_max,
+ }
+ del weight
+ flush()
+
+ unmatched_path = os.path.join(KEYMAPS_FOLDER, f'{name}_unmatched.json')
+ with open(unmatched_path, 'w') as f:
+ f.write(json.dumps(unmatched_obj, indent=4))
+
+ print(f'Saved unmatched keys to {unmatched_path}')
+
+# save ldm remainders
+remaining_ldm_values = OrderedDict()
+for key in unmatched_ldm_keys:
+ remaining_ldm_values[key] = ldm_state_dict[key].detach().to('cpu', torch.float16)
+
+save_file(remaining_ldm_values, os.path.join(KEYMAPS_FOLDER, f'{name}_ldm_base.safetensors'))
+print(f'Saved remaining ldm values to {os.path.join(KEYMAPS_FOLDER, f"{name}_ldm_base.safetensors")}')
+
+# do cleanup of some left overs and bugs
+to_remove = []
+for ldm_key, diffusers_key in ldm_diffusers_keymap.items():
+ # get rid of tmp merge keys used for slicing
+ if tmp_merge_key in diffusers_key or tmp_merge_key in ldm_key:
+ to_remove.append(ldm_key)
+
+for key in to_remove:
+ del ldm_diffusers_keymap[key]
+
+to_remove = []
+# remove identical shape mappings. Not sure why they exist but they do
+for ldm_key, shape_list in ldm_diffusers_shape_map.items():
+ # convert to json string to make it easier to compare
+ ldm_shape = json.dumps(shape_list[0])
+ diffusers_shape = json.dumps(shape_list[1])
+ if ldm_shape == diffusers_shape:
+ to_remove.append(ldm_key)
+
+for key in to_remove:
+ del ldm_diffusers_shape_map[key]
+
+dest_path = os.path.join(KEYMAPS_FOLDER, f'{name}.json')
+save_obj = OrderedDict()
+save_obj["ldm_diffusers_keymap"] = ldm_diffusers_keymap
+save_obj["ldm_diffusers_shape_map"] = ldm_diffusers_shape_map
+save_obj["ldm_diffusers_operator_map"] = ldm_operator_map
+save_obj["diffusers_ldm_operator_map"] = diffusers_operator_map
+with open(dest_path, 'w') as f:
+ f.write(json.dumps(save_obj, indent=4))
+
+print(f'Saved keymap to {dest_path}')
diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1a2983c82469a2c6e56874df8b184d30f2d23fc
--- /dev/null
+++ b/testing/merge_in_text_encoder_adapter.py
@@ -0,0 +1,180 @@
+import os
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+import json
+
+# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
+# te_path = "google/flan-t5-xl"
+# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
+# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
+model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch"
+te_path = "google/flan-t5-xl"
+te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1"
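+# swaps the pipeline's text encoder/tokenizer for T5 and patches the cross-attention to_k/to_v weights with the trained te adapter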
+
+
+print("Loading te adapter")
+te_aug_sd = load_file(te_aug_path)
+
+print("Loading model")
+is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path)
+
+# if "pixart" in model_path.lower():
+is_pixart = "pixart" in model_path.lower()
+
+pipeline_class = StableDiffusionPipeline
+
+# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
+
+if is_pixart:
+ pipeline_class = PixArtSigmaPipeline
+
+if is_diffusers:
+ sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
+else:
+ sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
+
+print("Loading Text Encoder")
+# Load the text encoder
+te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16)
+
+# patch it
+sd.text_encoder = te
+sd.tokenizer = T5Tokenizer.from_pretrained(te_path)
+
+if is_pixart:
+ unet = sd.transformer
+ unet_sd = sd.transformer.state_dict()
+else:
+ unet = sd.unet
+ unet_sd = sd.unet.state_dict()
+
+
+if is_pixart:
+ weight_idx = 0
+else:
+ weight_idx = 1
+
+new_cross_attn_dim = None
+
+# count the num of params in state dict
+start_params = sum([v.numel() for v in unet_sd.values()])
+
+print("Building")
+attn_processor_keys = []
+if is_pixart:
+ transformer: Transformer2DModel = unet
+ for i, module in transformer.transformer_blocks.named_children():
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+ # cross attention
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+else:
+ attn_processor_keys = list(unet.attn_processors.keys())
+
+for name in attn_processor_keys:
+ cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith(
+ "attn1") else \
+ unet.config['cross_attention_dim']
+ if name.startswith("mid_block"):
+ hidden_size = unet.config['block_out_channels'][-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config['block_out_channels']))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config['block_out_channels'][block_id]
+ elif name.startswith("transformer"):
+ hidden_size = unet.config['cross_attention_dim']
+ else:
+ # they didn't have this, but it would leave hidden_size undefined below
+ raise ValueError(f"unknown attn processor name: {name}")
+ if cross_attention_dim is None:
+ pass
+ else:
+ layer_name = name.split(".processor")[0]
+ to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+ to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
+
+ te_aug_name = None
+ while True:
+ if is_pixart:
+ te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
+ else:
+ te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
+ if f"{te_aug_name}.weight" in te_aug_sd:
+ # increment so we don't redo it next time
+ weight_idx += 1
+ break
+ else:
+ weight_idx += 1
+
+ if weight_idx > 1000:
+ raise ValueError("Could not find the next weight")
+
+ orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape)
+ new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape)
+ orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape)
+ new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape)
+
+ unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
+ unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]
+
+ if new_cross_attn_dim is None:
+ new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]
+
+
+
+if is_pixart:
+ # remove the caption projection weights
+ del unet_sd['caption_projection.linear_1.bias']
+ del unet_sd['caption_projection.linear_1.weight']
+ del unet_sd['caption_projection.linear_2.bias']
+ del unet_sd['caption_projection.linear_2.weight']
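+ # cross-attention now takes the T5 hidden states directly, so caption_channels is cleared in the config below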
+
+print("Saving unmodified model")
+sd = sd.to("cpu", torch.float16)
+sd.save_pretrained(
+ output_path,
+ safe_serialization=True,
+)
+
+# overwrite the unet
+if is_pixart:
+ unet_folder = os.path.join(output_path, "transformer")
+else:
+ unet_folder = os.path.join(output_path, "unet")
+
+# move state_dict to cpu
+unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()}
+
+meta = OrderedDict()
+meta["format"] = "pt"
+
+print("Patching")
+
+save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta)
+
+# load the json file
+with open(os.path.join(unet_folder, "config.json"), 'r') as f:
+ config = json.load(f)
+
+config['cross_attention_dim'] = new_cross_attn_dim
+
+if is_pixart:
+ config['caption_channels'] = None
+
+# save it
+with open(os.path.join(unet_folder, "config.json"), 'w') as f:
+ json.dump(config, f, indent=2)
+
+print("Done")
+
+new_params = sum([v.numel() for v in unet_sd.values()])
+
+# print new and old params with , formatted
+print(f"Old params: {start_params:,}")
+print(f"New params: {new_params:,}")
diff --git a/testing/shrink_pixart.py b/testing/shrink_pixart.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad27b1a0ea38612a2a4202261ca88a7875281db1
--- /dev/null
+++ b/testing/shrink_pixart.py
@@ -0,0 +1,62 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
+
+state_dict = load_file(model_path)
+
+meta = OrderedDict()
+meta["format"] = "pt"
+
+new_state_dict = {}
+
+# Move non-blocks over
+for key, value in state_dict.items():
+ if not key.startswith("transformer_blocks."):
+ new_state_dict[key] = value
+
+block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
+ 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
+ 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
+ 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
+ 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
+ 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
+ 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
+ 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
+ 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
+ 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
+ 'transformer_blocks.{idx}.scale_shift_table']
+
+# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
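+# blocks not in the keep list are averaged 50/50 into the preceding kept block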
+
+current_idx = 0
+for i in range(28):
+ if i not in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]:
+ # todo merge in with previous block
+ for name in block_names:
+ try:
+ new_state_dict_key = name.format(idx=current_idx - 1)
+ old_state_dict_key = name.format(idx=i)
+ new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
+ except KeyError:
+ raise KeyError(f"KeyError: {new_state_dict_key} or {old_state_dict_key}")
+ else:
+ for name in block_names:
+ new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
+ current_idx += 1
+
+
+# make sure they are all fp16 and on cpu
+for key, value in new_state_dict.items():
+ new_state_dict[key] = value.to(torch.float16).cpu()
+
+# save the new state dict
+save_file(new_state_dict, output_path, metadata=meta)
+
+new_param_count = sum([v.numel() for v in new_state_dict.values()])
+old_param_count = sum([v.numel() for v in state_dict.values()])
+
+print(f"Old param count: {old_param_count:,}")
+print(f"New param count: {new_param_count:,}")
\ No newline at end of file
diff --git a/testing/shrink_pixart2.py b/testing/shrink_pixart2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8c30cf87f38610ac23b31afdc94311fba8e3a41
--- /dev/null
+++ b/testing/shrink_pixart2.py
@@ -0,0 +1,81 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
+
+state_dict = load_file(model_path)
+
+meta = OrderedDict()
+meta["format"] = "pt"
+
+new_state_dict = {}
+
+# Move non-blocks over
+for key, value in state_dict.items():
+ if not key.startswith("transformer_blocks."):
+ new_state_dict[key] = value
+
+block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
+ 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
+ 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
+ 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
+ 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
+ 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
+ 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
+ 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
+ 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
+ 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
+ 'transformer_blocks.{idx}.scale_shift_table']
+
+# Blocks to keep
+# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27]
+keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
+
+
+def weighted_merge(kept_block, removed_block, weight):
+ return kept_block * (1 - weight) + removed_block * weight
+
+
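+# keep 16 of the 28 transformer blocks; each removed block is blended into its nearest kept neighbors, weighted by distance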
+# First, copy all kept blocks to new_state_dict
+for i, old_idx in enumerate(keep_blocks):
+ for name in block_names:
+ old_key = name.format(idx=old_idx)
+ new_key = name.format(idx=i)
+ new_state_dict[new_key] = state_dict[old_key].clone()
+
+# Then, merge information from removed blocks
+for i in range(28):
+ if i not in keep_blocks:
+ # Find the nearest kept blocks
+ prev_kept = max([b for b in keep_blocks if b < i])
+ next_kept = min([b for b in keep_blocks if b > i])
+
+ # Calculate the weight based on position
+ weight = (i - prev_kept) / (next_kept - prev_kept)
+
+ for name in block_names:
+ removed_key = name.format(idx=i)
+ prev_new_key = name.format(idx=keep_blocks.index(prev_kept))
+ next_new_key = name.format(idx=keep_blocks.index(next_kept))
+
+ # Weighted merge for previous kept block
+ new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight)
+
+ # Weighted merge for next kept block
+ new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key],
+ 1 - weight)
+
+# Convert to fp16 and move to CPU
+for key, value in new_state_dict.items():
+ new_state_dict[key] = value.to(torch.float16).cpu()
+
+# Save the new state dict
+save_file(new_state_dict, output_path, metadata=meta)
+
+new_param_count = sum([v.numel() for v in new_state_dict.values()])
+old_param_count = sum([v.numel() for v in state_dict.values()])
+
+print(f"Old param count: {old_param_count:,}")
+print(f"New param count: {new_param_count:,}")
\ No newline at end of file
diff --git a/testing/shrink_pixart_sm.py b/testing/shrink_pixart_sm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cea07bf154fdd653f1a928f3c553dc56580a828
--- /dev/null
+++ b/testing/shrink_pixart_sm.py
@@ -0,0 +1,84 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+ weight = weight.to(device, torch.float32)
+ original_shape = weight.shape
+ flattened = weight.view(-1, original_shape[-1])
+
+ if flattened.shape[1] <= target_size:
+ return weight
+
+ U, S, V = torch.svd(flattened)
+ reduced = torch.mm(U[:, :target_size], torch.diag(S[:target_size]))
+
+ if reduced.shape[1] < target_size:
+ padding = torch.zeros(reduced.shape[0], target_size - reduced.shape[1], device=device)
+ reduced = torch.cat((reduced, padding), dim=1)
+
+ return reduced.view(original_shape[:-1] + (target_size,))
+
+
+def reduce_bias(bias, target_size):
+ bias = bias.to(device, torch.float32)
+ original_size = bias.shape[0]
+
+ if original_size <= target_size:
+ return torch.nn.functional.pad(bias, (0, target_size - original_size))
+ else:
+ return bias.view(-1, original_size // target_size).mean(dim=1)[:target_size]
+
+
+# Load your original state dict
+state_dict = load_file(
+ "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+source_hidden_size = 1152
+target_hidden_size = 1024
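+# shrink the hidden dimension from 1152 to 1024 by truncation; cross-attn k/v input dims (text-encoder width) are left untouched, and the reduce_* helpers above are unused in this variant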
+
+for key, value in state_dict.items():
+ value = value.to(device, torch.float32)
+ if 'weight' in key or 'scale_shift_table' in key:
+ if value.shape[0] == source_hidden_size:
+ value = value[:target_hidden_size]
+ elif value.shape[0] == source_hidden_size * 4:
+ value = value[:target_hidden_size * 4]
+ elif value.shape[0] == source_hidden_size * 6:
+ value = value[:target_hidden_size * 6]
+
+ if len(value.shape) > 1 and value.shape[
+ 1] == source_hidden_size and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+ value = value[:, :target_hidden_size]
+ elif len(value.shape) > 1 and value.shape[1] == source_hidden_size * 4:
+ value = value[:, :target_hidden_size * 4]
+
+ elif 'bias' in key:
+ if value.shape[0] == source_hidden_size:
+ value = value[:target_hidden_size]
+ elif value.shape[0] == source_hidden_size * 4:
+ value = value[:target_hidden_size * 4]
+ elif value.shape[0] == source_hidden_size * 6:
+ value = value[:target_hidden_size * 6]
+
+ new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+ new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+ "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+ metadata=meta)
+
+print("Done!")
diff --git a/testing/shrink_pixart_sm2.py b/testing/shrink_pixart_sm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd3304dfc72e50e445fce27ae793e7544009aa1a
--- /dev/null
+++ b/testing/shrink_pixart_sm2.py
@@ -0,0 +1,110 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+ weight = weight.to(device, torch.float32)
+ original_shape = weight.shape
+
+ if len(original_shape) == 1:
+ # For 1D tensors, simply truncate
+ return weight[:target_size]
+
+ if original_shape[0] <= target_size:
+ return weight
+
+ # Reshape the tensor to 2D
+ flattened = weight.reshape(original_shape[0], -1)
+
+ # Perform SVD
+ U, S, V = torch.svd(flattened)
+
+ # Reduce the dimensions
+ reduced = torch.mm(U[:target_size, :], torch.diag(S)).mm(V.t())
+
+ # Reshape back to the original shape with reduced first dimension
+ new_shape = (target_size,) + original_shape[1:]
+ return reduced.reshape(new_shape)
+
+
+def reduce_bias(bias, target_size):
+ bias = bias.to(device, torch.float32)
+ return bias[:target_size]
+
+
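+# this variant reduces the first dimension of each weight by reconstructing it through an SVD rather than by plain truncation; biases are simply truncated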
+# Load your original state dict
+state_dict = load_file(
+ "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+for key, value in state_dict.items():
+ value = value.to(device, torch.float32)
+
+ if 'weight' in key or 'scale_shift_table' in key:
+ if value.shape[0] == 1152:
+ if len(value.shape) == 4:
+ orig_shape = value.shape
+ output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3])
+ # reshape to (1152, -1)
+ value = value.view(value.shape[0], -1)
+ value = reduce_weight(value, 512)
+ value = value.view(output_shape)
+ else:
+ # value = reduce_weight(value.t(), 576).t().contiguous()
+ value = reduce_weight(value, 512)
+ pass
+ elif value.shape[0] == 4608:
+ if len(value.shape) == 4:
+ orig_shape = value.shape
+ output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3])
+ value = value.view(value.shape[0], -1)
+ value = reduce_weight(value, 2048)
+ value = value.view(output_shape)
+ else:
+ value = reduce_weight(value, 2048)
+ elif value.shape[0] == 6912:
+ if len(value.shape) == 4:
+ orig_shape = value.shape
+ output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3])
+ value = value.view(value.shape[0], -1)
+ value = reduce_weight(value, 3072)
+ value = value.view(output_shape)
+ else:
+ value = reduce_weight(value, 3072)
+
+ if len(value.shape) > 1 and value.shape[
+ 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+ value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction
+ pass
+ elif len(value.shape) > 1 and value.shape[1] == 4608:
+ value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction
+ pass
+
+ elif 'bias' in key:
+ if value.shape[0] == 1152:
+ value = reduce_bias(value, 512)
+ elif value.shape[0] == 4608:
+ value = reduce_bias(value, 2048)
+ elif value.shape[0] == 6912:
+ value = reduce_bias(value, 3072)
+
+ new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+ new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+ "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+ metadata=meta)
+
+print("Done!")
\ No newline at end of file
diff --git a/testing/shrink_pixart_sm3.py b/testing/shrink_pixart_sm3.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8756aec45b4a5cb59315ab11a5bed320d74f7ba
--- /dev/null
+++ b/testing/shrink_pixart_sm3.py
@@ -0,0 +1,100 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+ weight = weight.to(device, torch.float32)
+ # resize so target_size is the first dimension
+ tmp_weight = weight.view(1, 1, weight.shape[0], weight.shape[1])
+
+ # use interpolate to resize the tensor
+ new_weight = torch.nn.functional.interpolate(tmp_weight, size=(target_size, weight.shape[1]), mode='bicubic', align_corners=True)
+
+ # reshape back to original shape
+ return new_weight.view(target_size, weight.shape[1])
+
+
+def reduce_bias(bias, target_size):
+ bias = bias.view(1, 1, bias.shape[0], 1)
+
+ new_bias = torch.nn.functional.interpolate(bias, size=(target_size, 1), mode='bicubic', align_corners=True)
+
+ return new_bias.view(target_size)
+
+
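+# this variant resizes weights and biases with bicubic interpolation instead of truncation or SVD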
+# Load your original state dict
+state_dict = load_file(
+ "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+for key, value in state_dict.items():
+ value = value.to(device, torch.float32)
+
+ if 'weight' in key or 'scale_shift_table' in key:
+ if value.shape[0] == 1152:
+ if len(value.shape) == 4:
+ orig_shape = value.shape
+ output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3])
+ # reshape to (1152, -1)
+ value = value.view(value.shape[0], -1)
+ value = reduce_weight(value, 512)
+ value = value.view(output_shape)
+ else:
+ # value = reduce_weight(value.t(), 576).t().contiguous()
+ value = reduce_weight(value, 512)
+ pass
+ elif value.shape[0] == 4608:
+ if len(value.shape) == 4:
+ orig_shape = value.shape
+ output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3])
+ value = value.view(value.shape[0], -1)
+ value = reduce_weight(value, 2048)
+ value = value.view(output_shape)
+ else:
+ value = reduce_weight(value, 2048)
+ elif value.shape[0] == 6912:
+ if len(value.shape) == 4:
+ orig_shape = value.shape
+ output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3])
+ value = value.view(value.shape[0], -1)
+ value = reduce_weight(value, 3072)
+ value = value.view(output_shape)
+ else:
+ value = reduce_weight(value, 3072)
+
+ if len(value.shape) > 1 and value.shape[
+ 1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+ value = reduce_weight(value.t(), 512).t().contiguous() # Transpose before and after reduction
+ pass
+ elif len(value.shape) > 1 and value.shape[1] == 4608:
+ value = reduce_weight(value.t(), 2048).t().contiguous() # Transpose before and after reduction
+ pass
+
+ elif 'bias' in key:
+ if value.shape[0] == 1152:
+ value = reduce_bias(value, 512)
+ elif value.shape[0] == 4608:
+ value = reduce_bias(value, 2048)
+ elif value.shape[0] == 6912:
+ value = reduce_bias(value, 3072)
+
+ new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+ new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+ "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+ metadata=meta)
+
+print("Done!")
\ No newline at end of file
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d97f2d92949de05dbc831cbdbdb764a5997dca
--- /dev/null
+++ b/testing/test_bucket_dataloader.py
@@ -0,0 +1,128 @@
+import time
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import sys
+import os
+import cv2
+import random
+from transformers import CLIPImageProcessor
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from toolkit.paths import SD_SCRIPTS_ROOT
+import torchvision.transforms.functional
+from toolkit.image_utils import show_img, show_tensors
+
+sys.path.append(SD_SCRIPTS_ROOT)
+
+from library.model_util import load_vae
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
+ trigger_dataloader_setup_epoch
+from toolkit.config_modules import DatasetConfig
+import argparse
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument('dataset_folder', type=str, default='input')
+parser.add_argument('--epochs', type=int, default=1)
+
+
+
+args = parser.parse_args()
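+# example invocation (path is illustrative): python testing/test_bucket_dataloader.py /path/to/dataset --epochs 2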
+
+dataset_folder = args.dataset_folder
+resolution = 1024
+bucket_tolerance = 64
+batch_size = 1
+
+clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+class FakeAdapter:
+ def __init__(self):
+ self.clip_image_processor = clip_processor
+
+
+## make fake sd
+class FakeSD:
+ def __init__(self):
+ self.adapter = FakeAdapter()
+
+
+
+
+dataset_config = DatasetConfig(
+ dataset_path=dataset_folder,
+ # clip_image_path=dataset_folder,
+ # square_crop=True,
+ resolution=resolution,
+ # caption_ext='json',
+ default_caption='default',
+ # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
+ buckets=True,
+ bucket_tolerance=bucket_tolerance,
+ # poi='person',
+ # shuffle_augmentations=True,
+ # augmentations=[
+ # {
+ # 'method': 'Posterize',
+ # 'num_bits': [(0, 4), (0, 4), (0, 4)],
+ # 'p': 1.0
+ # },
+ #
+ # ]
+)
+
+dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
+
+
+# run through an epoch and check sizes
+dataloader_iterator = iter(dataloader)
+for epoch in range(args.epochs):
+ for batch in tqdm(dataloader):
+ batch: 'DataLoaderBatchDTO'
+ img_batch = batch.tensor
+ batch_size, channels, height, width = img_batch.shape
+
+ # img_batch = color_block_imgs(img_batch, neg1_1=True)
+
+ # chunks = torch.chunk(img_batch, batch_size, dim=0)
+        # # put them side by side
+ # big_img = torch.cat(chunks, dim=3)
+ # big_img = big_img.squeeze(0)
+ #
+ # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
+ # big_control_img = torch.cat(control_chunks, dim=3)
+ # big_control_img = big_control_img.squeeze(0) * 2 - 1
+ #
+ #
+ # # resize control image
+ # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
+ #
+ # big_img = torch.cat([big_img, big_control_img], dim=2)
+ #
+ # min_val = big_img.min()
+ # max_val = big_img.max()
+ #
+ # big_img = (big_img / 2 + 0.5).clamp(0, 1)
+
+ big_img = img_batch
+ # big_img = big_img.clamp(-1, 1)
+
+ show_tensors(big_img)
+
+ # convert to image
+ # img = transforms.ToPILImage()(big_img)
+ #
+ # show_img(img)
+
+ time.sleep(0.2)
+ # if not last epoch
+ if epoch < args.epochs - 1:
+ trigger_dataloader_setup_epoch(dataloader)
+
+cv2.destroyAllWindows()
+
+print('done')
diff --git a/testing/test_model_load_save.py b/testing/test_model_load_save.py
new file mode 100644
index 0000000000000000000000000000000000000000..87bdfb3ef8246268f0660db6bf24822c74506c45
--- /dev/null
+++ b/testing/test_model_load_save.py
@@ -0,0 +1,172 @@
+import argparse
+import os
+# add project root to sys path
+import sys
+
+from tqdm import tqdm
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import torch
+from diffusers.loaders import LoraLoaderMixin
+from safetensors.torch import load_file
+from collections import OrderedDict
+import json
+
+from toolkit.config_modules import ModelConfig
+from toolkit.paths import KEYMAPS_ROOT
+from toolkit.saving import convert_state_dict_to_ldm_with_mapping, get_ldm_state_dict_from_diffusers
+from toolkit.stable_diffusion_model import StableDiffusion
+
+# this was just used to match the vae keys to the diffusers keys
+# you probably won't need this. Unless they change them.... again... again
+# on second thought, you probably will
+
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+device = torch.device('cpu')
+dtype = torch.float32
+
+parser = argparse.ArgumentParser()
+
+# require at least one config file
+parser.add_argument(
+ 'file_1',
+ nargs='+',
+ type=str,
+    help='Path to an LDM model'
+)
+
+parser.add_argument(
+ '--is_xl',
+ action='store_true',
+ help='Is the model an XL model'
+)
+
+parser.add_argument(
+ '--is_v2',
+ action='store_true',
+ help='Is the model a v2 model'
+)
+
+args = parser.parse_args()
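+
+# Example invocation (model path is illustrative):
+#   python testing/test_model_load_save.py /path/to/sdxl_model.safetensors --is_xl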
+
+find_matches = False
+
+print("Loading model")
+state_dict_file_1 = load_file(args.file_1[0])
+state_dict_1_keys = list(state_dict_file_1.keys())
+
+print("Loading model into diffusers format")
+model_config = ModelConfig(
+ name_or_path=args.file_1[0],
+ is_xl=args.is_xl
+)
+sd = StableDiffusion(
+ model_config=model_config,
+ device=device,
+)
+sd.load_model()
+
+# load our base
+base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
+mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
+
+print("Converting model back to LDM")
+version_string = '1'
+if args.is_v2:
+ version_string = '2'
+if args.is_xl:
+ version_string = 'sdxl'
+# convert the state dict
+state_dict_file_2 = get_ldm_state_dict_from_diffusers(
+ sd.state_dict(),
+ version_string,
+ device='cpu',
+ dtype=dtype
+)
+
+# state_dict_file_2 = load_file(args.file_2[0])
+
+state_dict_2_keys = list(state_dict_file_2.keys())
+keys_in_both = []
+
+keys_not_in_state_dict_2 = []
+for key in state_dict_1_keys:
+ if key not in state_dict_2_keys:
+ keys_not_in_state_dict_2.append(key)
+
+keys_not_in_state_dict_1 = []
+for key in state_dict_2_keys:
+ if key not in state_dict_1_keys:
+ keys_not_in_state_dict_1.append(key)
+
+keys_in_both = []
+for key in state_dict_1_keys:
+ if key in state_dict_2_keys:
+ keys_in_both.append(key)
+
+# sort them
+keys_not_in_state_dict_2.sort()
+keys_not_in_state_dict_1.sort()
+keys_in_both.sort()
+
+if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0:
+ print("All keys match!")
+ print("Checking values...")
+ mismatch_keys = []
+ loss = torch.nn.MSELoss()
+ tolerance = 1e-6
+ for key in tqdm(keys_in_both):
+ if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance:
+ print(f"Values for key {key} don't match!")
+ print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}")
+ mismatch_keys.append(key)
+
+ if len(mismatch_keys) == 0:
+ print("All values match!")
+ else:
+        print("Some values don't match!")
+ print(mismatch_keys)
+ mismatched_path = os.path.join(project_root, 'config', 'mismatch.json')
+ with open(mismatched_path, 'w') as f:
+ f.write(json.dumps(mismatch_keys, indent=4))
+ exit(0)
+
+else:
+    print("Keys don't match! Generating info...")
+
+json_data = {
+ "both": keys_in_both,
+ "not_in_state_dict_2": keys_not_in_state_dict_2,
+ "not_in_state_dict_1": keys_not_in_state_dict_1
+}
+json_data = json.dumps(json_data, indent=4)
+
+remaining_diffusers_values = OrderedDict()
+for key in keys_not_in_state_dict_1:
+ remaining_diffusers_values[key] = state_dict_file_2[key]
+
+# print(remaining_diffusers_values.keys())
+
+remaining_ldm_values = OrderedDict()
+for key in keys_not_in_state_dict_2:
+ remaining_ldm_values[key] = state_dict_file_1[key]
+
+# print(json_data)
+
+
+json_save_path = os.path.join(project_root, 'config', 'keys.json')
+json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
+json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
+state_dict_1_filename = os.path.basename(args.file_1[0])
+# state_dict_2_filename = os.path.basename(args.file_2[0])
+# save key names for each in own file
+with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
+ f.write(json.dumps(state_dict_1_keys, indent=4))
+
+with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f:
+ f.write(json.dumps(state_dict_2_keys, indent=4))
+
+with open(json_save_path, 'w') as f:
+ f.write(json_data)
diff --git a/testing/test_vae.py b/testing/test_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b31f6311024833b1dc4c46cf431f7ae68f09d7
--- /dev/null
+++ b/testing/test_vae.py
@@ -0,0 +1,113 @@
+import argparse
+import os
+from PIL import Image
+import torch
+from torchvision.transforms import Resize, ToTensor
+from diffusers import AutoencoderKL
+from pytorch_fid import fid_score
+from skimage.metrics import peak_signal_noise_ratio as psnr
+import lpips
+from tqdm import tqdm
+from torchvision import transforms
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def load_images(folder_path):
+ images = []
+ for filename in os.listdir(folder_path):
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
+ img_path = os.path.join(folder_path, filename)
+ images.append(img_path)
+ return images
+
+
+def parameter_count(model):
+    state_dict = model.state_dict()
+    total_params = 0
+    for key in state_dict:
+        total_params += torch.numel(state_dict[key])
+    return int(total_params)
+
+
+def calculate_metrics(vae, images, max_imgs=-1):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ vae = vae.to(device)
+ lpips_model = lpips.LPIPS(net='alex').to(device)
+
+ rfid_scores = []
+ psnr_scores = []
+ lpips_scores = []
+
+ # transform = transforms.Compose([
+ # transforms.Resize(256, antialias=True),
+ # transforms.CenterCrop(256)
+ # ])
+ # needs values between -1 and 1
+ to_tensor = ToTensor()
+
+ if max_imgs > 0 and len(images) > max_imgs:
+ images = images[:max_imgs]
+
+ for img_path in tqdm(images):
+ try:
+ img = Image.open(img_path).convert('RGB')
+ # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device)
+ img_tensor = to_tensor(img).unsqueeze(0).to(device)
+ img_tensor = 2 * img_tensor - 1
+ # if width or height is not divisible by 8, crop it
+ if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0:
+ img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8]
+
+ except Exception as e:
+ print(f"Error processing {img_path}: {e}")
+ continue
+
+
+ with torch.no_grad():
+ reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample
+
+ # Calculate rFID
+ # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed)
+ # rfid_scores.append(rfid)
+
+ # Calculate PSNR
+ psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy())
+ psnr_scores.append(psnr_val)
+
+ # Calculate LPIPS
+ lpips_val = lpips_model(img_tensor, reconstructed).item()
+ lpips_scores.append(lpips_val)
+
+ # avg_rfid = sum(rfid_scores) / len(rfid_scores)
+ avg_rfid = 0
+ avg_psnr = sum(psnr_scores) / len(psnr_scores)
+ avg_lpips = sum(lpips_scores) / len(lpips_scores)
+
+ return avg_rfid, avg_psnr, avg_lpips
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
+ parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
+ parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
+ parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
+ args = parser.parse_args()
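+
+    # Example invocation (paths are illustrative):
+    #   python testing/test_vae.py --vae_path /path/to/vae.safetensors --image_folder /path/to/images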
+
+ if os.path.isfile(args.vae_path):
+ vae = AutoencoderKL.from_single_file(args.vae_path)
+ else:
+ vae = AutoencoderKL.from_pretrained(args.vae_path)
+ vae.eval()
+ vae = vae.to(device)
+    print(f"Model has {parameter_count(vae)} parameters")
+ images = load_images(args.image_folder)
+
+ avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
+
+ # print(f"Average rFID: {avg_rfid}")
+ print(f"Average PSNR: {avg_psnr}")
+ print(f"Average LPIPS: {avg_lpips}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/testing/test_vae_cycle.py b/testing/test_vae_cycle.py
new file mode 100644
index 0000000000000000000000000000000000000000..175e8f8fa5cdb4cb652225f4d95e7a2cbb04fd29
--- /dev/null
+++ b/testing/test_vae_cycle.py
@@ -0,0 +1,112 @@
+import os
+
+import torch
+from safetensors.torch import load_file
+from collections import OrderedDict
+from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers
+import json
+# this was just used to match the vae keys to the diffusers keys
+# you probably won't need this. Unless they change them.... again... again
+# on second thought, you probably will
+
+device = torch.device('cpu')
+dtype = torch.float32
+vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors'
+
+find_matches = False
+
+state_dict_ldm = load_file(vae_path)
+diffusers_vae = load_vae(vae_path, dtype=torch.float32).to(device)
+
+ldm_keys = state_dict_ldm.keys()
+
+matched_keys = {}
+duplicated_keys = {}
+
+if find_matches:
+ # find values that match with a very low mse
+ for ldm_key in ldm_keys:
+ ldm_value = state_dict_ldm[ldm_key]
+ for diffusers_key in list(diffusers_vae.state_dict().keys()):
+ diffusers_value = diffusers_vae.state_dict()[diffusers_key]
+ if diffusers_key in vae_keys_squished_on_diffusers:
+ diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1)
+ # if they are not same shape, skip
+ if ldm_value.shape != diffusers_value.shape:
+ continue
+ mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value)
+ if mse < 1e-6:
+ if ldm_key in list(matched_keys.keys()):
+ print(f'{ldm_key} already matched to {matched_keys[ldm_key]}')
+ if ldm_key in duplicated_keys:
+ duplicated_keys[ldm_key].append(diffusers_key)
+ else:
+ duplicated_keys[ldm_key] = [diffusers_key]
+ continue
+ matched_keys[ldm_key] = diffusers_key
+ is_matched = True
+ break
+
+ print(f'Found {len(matched_keys)} matches')
+
+dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae)
+dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys())
+keys_in_both = []
+
+keys_not_in_diffusers = []
+for key in ldm_keys:
+ if key not in dif_to_ldm_state_dict_keys:
+ keys_not_in_diffusers.append(key)
+
+keys_not_in_ldm = []
+for key in dif_to_ldm_state_dict_keys:
+ if key not in ldm_keys:
+ keys_not_in_ldm.append(key)
+
+keys_in_both = []
+for key in ldm_keys:
+ if key in dif_to_ldm_state_dict_keys:
+ keys_in_both.append(key)
+
+# sort them
+keys_not_in_diffusers.sort()
+keys_not_in_ldm.sort()
+keys_in_both.sort()
+
+# print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}')
+# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}')
+# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}')
+
+json_data = {
+ "both": keys_in_both,
+ "ldm": keys_not_in_diffusers,
+ "diffusers": keys_not_in_ldm
+}
+json_data = json.dumps(json_data, indent=4)
+
+remaining_diffusers_values = OrderedDict()
+for key in keys_not_in_ldm:
+ remaining_diffusers_values[key] = dif_to_ldm_state_dict[key]
+
+# print(remaining_diffusers_values.keys())
+
+remaining_ldm_values = OrderedDict()
+for key in keys_not_in_diffusers:
+ remaining_ldm_values[key] = state_dict_ldm[key]
+
+# print(json_data)
+
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+json_save_path = os.path.join(project_root, 'config', 'keys.json')
+json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
+json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
+
+with open(json_save_path, 'w') as f:
+ f.write(json_data)
+if find_matches:
+ with open(json_matched_save_path, 'w') as f:
+ f.write(json.dumps(matched_keys, indent=4))
+ with open(json_duped_save_path, 'w') as f:
+ f.write(json.dumps(duplicated_keys, indent=4))
diff --git a/toolkit/__init__.py b/toolkit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/toolkit/assistant_lora.py b/toolkit/assistant_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdeca968ad6c6d5b403f3786eea81efc33944c94
--- /dev/null
+++ b/toolkit/assistant_lora.py
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING
+from toolkit.config_modules import NetworkConfig
+from toolkit.lora_special import LoRASpecialNetwork
+from safetensors.torch import load_file
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
+ if not sd.is_flux:
+ raise ValueError("Only Flux models can load assistant adapters currently.")
+ pipe = sd.pipeline
+ print(f"Loading assistant adapter from {adapter_path}")
+ adapter_name = adapter_path.split("/")[-1].split(".")[0]
+ lora_state_dict = load_file(adapter_path)
+
+ linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0])
+ # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item())
+ linear_alpha = linear_dim
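+    # assume alpha == linear dim since no alpha value is read from this state dict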
+ transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict
+ # get dim and scale
+ network_config = NetworkConfig(
+ linear=linear_dim,
+ linear_alpha=linear_alpha,
+ transformer_only=transformer_only,
+ )
+
+ network = LoRASpecialNetwork(
+ text_encoder=pipe.text_encoder,
+ unet=pipe.transformer,
+ lora_dim=network_config.linear,
+ multiplier=1.0,
+ alpha=network_config.linear_alpha,
+ train_unet=True,
+ train_text_encoder=False,
+ is_flux=True,
+ network_config=network_config,
+ network_type=network_config.type,
+ transformer_only=network_config.transformer_only,
+ is_assistant_adapter=True
+ )
+ network.apply_to(
+ pipe.text_encoder,
+ pipe.transformer,
+ apply_text_encoder=False,
+ apply_unet=True
+ )
+ network.force_to(sd.device_torch, dtype=sd.torch_dtype)
+ network.eval()
+ network._update_torch_multiplier()
+ network.load_weights(lora_state_dict)
+ network.is_active = True
+
+ return network
diff --git a/toolkit/basic.py b/toolkit/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d32a9d2f356bbda1d3e629c6ebc6688b5f1d458
--- /dev/null
+++ b/toolkit/basic.py
@@ -0,0 +1,56 @@
+import gc
+
+import torch
+
+
+def value_map(inputs, min_in, max_in, min_out, max_out):
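+    """Linearly remap `inputs` from [min_in, max_in] to [min_out, max_out].
+
+    e.g. value_map(5, 0, 10, 0, 100) == 50.0
+    """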
+ return (inputs - min_in) * (max_out - min_out) / (max_in - min_in) + min_out
+
+
+def flush(garbage_collect=True):
+ torch.cuda.empty_cache()
+ if garbage_collect:
+ gc.collect()
+
+
+def get_mean_std(tensor):
+ if len(tensor.shape) == 3:
+ tensor = tensor.unsqueeze(0)
+ elif len(tensor.shape) != 4:
+ raise Exception("Expected tensor of shape (batch_size, channels, width, height)")
+ mean, variance = torch.mean(
+ tensor, dim=[2, 3], keepdim=True
+ ), torch.var(
+ tensor, dim=[2, 3],
+ keepdim=True
+ )
+ std = torch.sqrt(variance + 1e-5)
+ return mean, std
+
+
+def adain(content_features, style_features):
+ # Assumes that the content and style features are of shape (batch_size, channels, width, height)
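+    # Adaptive Instance Normalization: re-normalize the content features so they carry
+    # the per-channel mean/std of the style features (steps 1-4 below)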
+
+ dims = [2, 3]
+ if len(content_features.shape) == 3:
+ # content_features = content_features.unsqueeze(0)
+ # style_features = style_features.unsqueeze(0)
+ dims = [1]
+
+ # Step 1: Calculate mean and variance of content features
+ content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features,
+ dim=dims,
+ keepdim=True)
+ # Step 2: Calculate mean and variance of style features
+ style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims,
+ keepdim=True)
+
+ # Step 3: Normalize content features
+ content_std = torch.sqrt(content_var + 1e-5)
+ normalized_content = (content_features - content_mean) / content_std
+
+ # Step 4: Scale and shift normalized content with style's statistics
+ style_std = torch.sqrt(style_var + 1e-5)
+ stylized_content = normalized_content * style_std + style_mean
+
+ return stylized_content
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
new file mode 100644
index 0000000000000000000000000000000000000000..835c9eb96f5bb38bb6d871530b89cb835dc47091
--- /dev/null
+++ b/toolkit/buckets.py
@@ -0,0 +1,174 @@
+from typing import Type, List, Union, TypedDict
+
+
+class BucketResolution(TypedDict):
+ width: int
+ height: int
+
+
+# resolutions SDXL was trained on with a 1024x1024 base resolution
+resolutions_1024: List[BucketResolution] = [
+ # SDXL Base resolution
+ {"width": 1024, "height": 1024},
+ # SDXL Resolutions, widescreen
+ {"width": 2048, "height": 512},
+ {"width": 1984, "height": 512},
+ {"width": 1920, "height": 512},
+ {"width": 1856, "height": 512},
+ {"width": 1792, "height": 576},
+ {"width": 1728, "height": 576},
+ {"width": 1664, "height": 576},
+ {"width": 1600, "height": 640},
+ {"width": 1536, "height": 640},
+ {"width": 1472, "height": 704},
+ {"width": 1408, "height": 704},
+ {"width": 1344, "height": 704},
+ {"width": 1344, "height": 768},
+ {"width": 1280, "height": 768},
+ {"width": 1216, "height": 832},
+ {"width": 1152, "height": 832},
+ {"width": 1152, "height": 896},
+ {"width": 1088, "height": 896},
+ {"width": 1088, "height": 960},
+ {"width": 1024, "height": 960},
+ # SDXL Resolutions, portrait
+ {"width": 960, "height": 1024},
+ {"width": 960, "height": 1088},
+ {"width": 896, "height": 1088},
+ {"width": 896, "height": 1152}, # 2:3
+ {"width": 832, "height": 1152},
+ {"width": 832, "height": 1216},
+ {"width": 768, "height": 1280},
+ {"width": 768, "height": 1344},
+ {"width": 704, "height": 1408},
+ {"width": 704, "height": 1472},
+ {"width": 640, "height": 1536},
+ {"width": 640, "height": 1600},
+ {"width": 576, "height": 1664},
+ {"width": 576, "height": 1728},
+ {"width": 576, "height": 1792},
+ {"width": 512, "height": 1856},
+ {"width": 512, "height": 1920},
+ {"width": 512, "height": 1984},
+ {"width": 512, "height": 2048},
+ # extra wides
+ {"width": 8192, "height": 128},
+ {"width": 128, "height": 8192},
+]
+
+# Even numbers so they can be patched more easily
+resolutions_dit_1024: List[BucketResolution] = [
+ # Base resolution
+ {"width": 1024, "height": 1024},
+ # widescreen
+ {"width": 2048, "height": 512},
+ {"width": 1792, "height": 576},
+ {"width": 1728, "height": 576},
+ {"width": 1664, "height": 576},
+ {"width": 1600, "height": 640},
+ {"width": 1536, "height": 640},
+ {"width": 1472, "height": 704},
+ {"width": 1408, "height": 704},
+ {"width": 1344, "height": 704},
+ {"width": 1344, "height": 768},
+ {"width": 1280, "height": 768},
+ {"width": 1216, "height": 832},
+ {"width": 1152, "height": 832},
+ {"width": 1152, "height": 896},
+ {"width": 1088, "height": 896},
+ {"width": 1088, "height": 960},
+ {"width": 1024, "height": 960},
+ # portrait
+ {"width": 960, "height": 1024},
+ {"width": 960, "height": 1088},
+ {"width": 896, "height": 1088},
+ {"width": 896, "height": 1152}, # 2:3
+ {"width": 832, "height": 1152},
+ {"width": 832, "height": 1216},
+ {"width": 768, "height": 1280},
+ {"width": 768, "height": 1344},
+ {"width": 704, "height": 1408},
+ {"width": 704, "height": 1472},
+ {"width": 640, "height": 1536},
+ {"width": 640, "height": 1600},
+ {"width": 576, "height": 1664},
+ {"width": 576, "height": 1728},
+ {"width": 576, "height": 1792},
+ {"width": 512, "height": 1856},
+ {"width": 512, "height": 1920},
+ {"width": 512, "height": 1984},
+ {"width": 512, "height": 2048},
+]
+
+
+def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
+    # determine scaler from 1024 to resolution
+ scaler = resolution / 1024
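+    # e.g. resolution=512 -> scaler=0.5, so the 1216x832 bucket becomes 608x416
+    # (each side is then snapped down to a multiple of `divisibility` if needed)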
+
+ bucket_size_list = []
+ for bucket in resolutions_1024:
+        # width and height must be divisible by `divisibility`
+ width = int(bucket["width"] * scaler)
+ height = int(bucket["height"] * scaler)
+ if width % divisibility != 0:
+ width = width - (width % divisibility)
+ if height % divisibility != 0:
+ height = height - (height % divisibility)
+ bucket_size_list.append({"width": width, "height": height})
+
+ return bucket_size_list
+
+
+def get_resolution(width, height):
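+    # e.g. a 1920x1080 image covers 2,073,600 pixels, the same as a 1440x1440 square,
+    # so its bucket "resolution" is 1440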
+ num_pixels = width * height
+ # determine same number of pixels for square image
+ square_resolution = int(num_pixels ** 0.5)
+ return square_resolution
+
+
+def get_bucket_for_image_size(
+ width: int,
+ height: int,
+ bucket_size_list: List[BucketResolution] = None,
+ resolution: Union[int, None] = None,
+ divisibility: int = 8
+) -> BucketResolution:
+
+ if bucket_size_list is None and resolution is None:
+ # get resolution from width and height
+ resolution = get_resolution(width, height)
+ if bucket_size_list is None:
+ # if real resolution is smaller, use that instead
+ real_resolution = get_resolution(width, height)
+ resolution = min(resolution, real_resolution)
+ bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)
+
+ # Check for exact match first
+ for bucket in bucket_size_list:
+ if bucket["width"] == width and bucket["height"] == height:
+ return bucket
+
+ # If exact match not found, find the closest bucket
+ closest_bucket = None
+ min_removed_pixels = float("inf")
+
+ for bucket in bucket_size_list:
+ scale_w = bucket["width"] / width
+ scale_h = bucket["height"] / height
+
+        # Use the larger scale factor so the scaled image fully covers the bucket; the overflow is cropped rather than padded.
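+        # e.g. a 1000x1000 image vs a 1216x832 bucket: scale_w is about 1.216 and scale_h about 0.832,
+        # so the larger factor is used and only the height needs cropping after the resize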
+ scale = max(scale_w, scale_h)
+
+ new_width = int(width * scale)
+ new_height = int(height * scale)
+
+ removed_pixels = (new_width - bucket["width"]) * new_height + (new_height - bucket["height"]) * new_width
+
+ if removed_pixels < min_removed_pixels:
+ min_removed_pixels = removed_pixels
+ closest_bucket = bucket
+
+ if closest_bucket is None:
+ raise ValueError("No suitable bucket found")
+
+ return closest_bucket
diff --git a/toolkit/civitai.py b/toolkit/civitai.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef505ad833f951470eb2e6a9c7b26059a6509604
--- /dev/null
+++ b/toolkit/civitai.py
@@ -0,0 +1,217 @@
+from toolkit.paths import MODELS_PATH
+import requests
+import os
+import json
+import tqdm
+
+
+class ModelCache:
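+    # On-disk cache (MODELS_PATH/.ai_toolkit_cache.json) with the shape:
+    #   {"models": {"<model_id>": {"<model_version_id>": {"model_path": "/path/to/file.safetensors"}}}}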
+ def __init__(self):
+ self.raw_cache = {}
+ self.cache_path = os.path.join(MODELS_PATH, '.ai_toolkit_cache.json')
+ if os.path.exists(self.cache_path):
+ with open(self.cache_path, 'r') as f:
+ all_cache = json.load(f)
+ if 'models' in all_cache:
+ self.raw_cache = all_cache['models']
+ else:
+ self.raw_cache = all_cache
+
+ def get_model_path(self, model_id: int, model_version_id: int = None):
+ if str(model_id) not in self.raw_cache:
+ return None
+ if model_version_id is None:
+ # get latest version
+ model_version_id = max([int(x) for x in self.raw_cache[str(model_id)].keys()])
+ if model_version_id is None:
+ return None
+ model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path']
+ # check if model path exists
+ if not os.path.exists(model_path):
+ # remove version from cache
+ del self.raw_cache[str(model_id)][str(model_version_id)]
+ self.save()
+ return None
+ return model_path
+ else:
+ if str(model_version_id) not in self.raw_cache[str(model_id)]:
+ return None
+ model_path = self.raw_cache[str(model_id)][str(model_version_id)]['model_path']
+ # check if model path exists
+ if not os.path.exists(model_path):
+ # remove version from cache
+ del self.raw_cache[str(model_id)][str(model_version_id)]
+ self.save()
+ return None
+ return model_path
+
+ def update_cache(self, model_id: int, model_version_id: int, model_path: str):
+ if str(model_id) not in self.raw_cache:
+ self.raw_cache[str(model_id)] = {}
+ if str(model_version_id) not in self.raw_cache[str(model_id)]:
+ self.raw_cache[str(model_id)][str(model_version_id)] = {}
+ self.raw_cache[str(model_id)][str(model_version_id)] = {
+ 'model_path': model_path
+ }
+ self.save()
+
+ def save(self):
+ if not os.path.exists(os.path.dirname(self.cache_path)):
+ os.makedirs(os.path.dirname(self.cache_path), exist_ok=True)
+ all_cache = {'models': {}}
+ if os.path.exists(self.cache_path):
+ # load it first
+ with open(self.cache_path, 'r') as f:
+ all_cache = json.load(f)
+
+ all_cache['models'] = self.raw_cache
+
+ with open(self.cache_path, 'w') as f:
+ json.dump(all_cache, f, indent=2)
+
+
+def get_model_download_info(model_id: int, model_version_id: int = None):
+ # curl https://civitai.com/api/v1/models?limit=3&types=TextualInversion \
+ # -H "Content-Type: application/json" \
+ # -X GET
+ print(
+ f"Getting model info for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}")
+ endpoint = f"https://civitai.com/api/v1/models/{model_id}"
+
+ # get the json
+ response = requests.get(endpoint)
+ response.raise_for_status()
+ model_data = response.json()
+
+ model_version = None
+
+ # go through versions and get the top one if one is not set
+ for version in model_data['modelVersions']:
+ if model_version_id is not None:
+ if str(version['id']) == str(model_version_id):
+ model_version = version
+ break
+ else:
+ # get first version
+ model_version = version
+ break
+
+ if model_version is None:
+ raise ValueError(
+ f"Could not find a model version for model id: {model_id}{f' and version id: {model_version_id}' if model_version_id is not None else ''}")
+
+ model_file = None
+ # go through files and prefer fp16 safetensors
+ # "metadata": {
+ # "fp": "fp16",
+ # "size": "pruned",
+ # "format": "SafeTensor"
+ # },
+ # todo check pickle scans and skip if not good
+ # try to get fp16 safetensor
+ for file in model_version['files']:
+ if file['metadata']['fp'] == 'fp16' and file['metadata']['format'] == 'SafeTensor':
+ model_file = file
+ break
+
+ if model_file is None:
+ # try to get primary
+ for file in model_version['files']:
+ if file['primary']:
+ model_file = file
+ break
+
+ if model_file is None:
+ # try to get any safetensor
+ for file in model_version['files']:
+ if file['metadata']['format'] == 'SafeTensor':
+ model_file = file
+ break
+
+ if model_file is None:
+ # try to get any fp16
+ for file in model_version['files']:
+ if file['metadata']['fp'] == 'fp16':
+ model_file = file
+ break
+
+ if model_file is None:
+ # try to get any
+ for file in model_version['files']:
+ model_file = file
+ break
+
+ if model_file is None:
+ raise ValueError(f"Could not find a model file to download for model id: {model_id}")
+
+ return model_file, model_version['id']
+
+
+def get_model_path_from_url(url: str):
+    # get query params from the url if they are set
+    # e.g. https://civitai.com/models/25694?modelVersionId=127742
+ query_params = {}
+ if '?' in url:
+ query_string = url.split('?')[1]
+ query_params = dict(qc.split("=") for qc in query_string.split("&"))
+
+ # get model id from url
+ model_id = url.split('/')[-1]
+ # remove query params from model id
+ if '?' in model_id:
+ model_id = model_id.split('?')[0]
+ if model_id.isdigit():
+ model_id = int(model_id)
+ else:
+ raise ValueError(f"Invalid model id: {model_id}")
+
+ model_cache = ModelCache()
+ model_path = model_cache.get_model_path(model_id, query_params.get('modelVersionId', None))
+ if model_path is not None:
+ return model_path
+ else:
+ # download model
+ file_info, model_version_id = get_model_download_info(model_id, query_params.get('modelVersionId', None))
+
+ download_url = file_info['downloadUrl'] # url does not work directly
+ size_kb = file_info['sizeKB']
+ filename = file_info['name']
+ model_path = os.path.join(MODELS_PATH, filename)
+
+ # download model
+        print(f"Did not find model locally, downloading model from: {download_url}")
+
+        # use tqdm to show the status of the download
+ response = requests.get(download_url, stream=True)
+ response.raise_for_status()
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ tmp_path = os.path.join(MODELS_PATH, f".download_tmp_{filename}")
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+ # remove tmp file if it exists
+ if os.path.exists(tmp_path):
+ os.remove(tmp_path)
+
+ try:
+
+ with open(tmp_path, 'wb') as f:
+ for data in response.iter_content(block_size):
+ progress_bar.update(len(data))
+ f.write(data)
+ progress_bar.close()
+ # move to final path
+ os.rename(tmp_path, model_path)
+ model_cache.update_cache(model_id, model_version_id, model_path)
+
+ return model_path
+ except Exception as e:
+ # remove tmp file
+ os.remove(tmp_path)
+ raise e
+
+
+# if is main
+if __name__ == '__main__':
+ model_path = get_model_path_from_url("https://civitai.com/models/25694?modelVersionId=127742")
+ print(model_path)
diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ccc920caac68aa43a2b1ddc944079d88feb50a4
--- /dev/null
+++ b/toolkit/clip_vision_adapter.py
@@ -0,0 +1,406 @@
+from typing import TYPE_CHECKING, Mapping, Any
+
+import torch
+import weakref
+
+from toolkit.config_modules import AdapterConfig
+from toolkit.models.clip_fusion import ZipperBlock
+from toolkit.models.zipper_resampler import ZipperModule
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.train_tools import get_torch_dtype
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection,
+ CLIPVisionModel
+)
+
+from toolkit.resampler import Resampler
+
+import torch.nn as nn
+
+
+class Embedder(nn.Module):
+ def __init__(
+ self,
+ num_input_tokens: int = 1,
+ input_dim: int = 1024,
+ num_output_tokens: int = 8,
+ output_dim: int = 768,
+ mid_dim: int = 1024
+ ):
+ super(Embedder, self).__init__()
+ self.num_output_tokens = num_output_tokens
+ self.num_input_tokens = num_input_tokens
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+
+ self.layer_norm = nn.LayerNorm(input_dim)
+ self.fc1 = nn.Linear(input_dim, mid_dim)
+ self.gelu = nn.GELU()
+ # self.fc2 = nn.Linear(mid_dim, mid_dim)
+ self.fc2 = nn.Linear(mid_dim, mid_dim)
+
+ self.fc2.weight.data.zero_()
+
+ self.layer_norm2 = nn.LayerNorm(mid_dim)
+ self.fc3 = nn.Linear(mid_dim, mid_dim)
+ self.gelu2 = nn.GELU()
+ self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
+
+ # set the weights to 0
+ self.fc3.weight.data.zero_()
+ self.fc4.weight.data.zero_()
+
+
+ # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
+ # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
+
+ def forward(self, x):
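+        # project vision-encoder features to num_output_tokens embeddings of size output_dim
+        # (the flattened fc4 output is reshaped into per-token embeddings below)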
+ if len(x.shape) == 2:
+ x = x.unsqueeze(1)
+ x = self.layer_norm(x)
+ x = self.fc1(x)
+ x = self.gelu(x)
+ x = self.fc2(x)
+ x = self.layer_norm2(x)
+ x = self.fc3(x)
+ x = self.gelu2(x)
+ x = self.fc4(x)
+
+ x = x.view(-1, self.num_output_tokens, self.output_dim)
+
+ return x
+
+
+class ClipVisionAdapter(torch.nn.Module):
+ def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig):
+ super().__init__()
+ self.config = adapter_config
+ self.trigger = adapter_config.trigger
+        self.trigger_class_name = adapter_config.trigger_class_name
+        # assumed default: when an image embed is dropped (unconditional), fill it with
+        # scaled random noise rather than zeros (used in get_clip_image_embeds_from_tensors)
+        self.clip_noise_zero = True
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ # embedding stuff
+ self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
+ self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer]
+ placeholder_tokens = [self.trigger]
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, self.config.num_tokens):
+ additional_tokens.append(f"{self.trigger}_{i}")
+ placeholder_tokens += additional_tokens
+
+ # handle dual tokenizer
+ self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [
+ self.sd_ref().tokenizer]
+ self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [
+ self.sd_ref().text_encoder]
+
+ self.placeholder_token_ids = []
+ self.embedding_tokens = []
+
+ print(f"Adding {placeholder_tokens} tokens to tokenizer")
+ print(f"Adding {self.config.num_tokens} tokens to tokenizer")
+
+
+ for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
+ num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+ if num_added_tokens != self.config.num_tokens:
+ raise ValueError(
+ f"The tokenizer already contains the token {self.trigger}. Please pass a different"
+ f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False)
+ # if length of token ids is more than number of orm embedding tokens fill with *
+ if len(init_token_ids) > self.config.num_tokens:
+ init_token_ids = init_token_ids[:self.config.num_tokens]
+ elif len(init_token_ids) < self.config.num_tokens:
+ pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
+ init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids))
+
+ placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
+ self.placeholder_token_ids.append(placeholder_token_ids)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+            # replace "[name]" with this during training. It is generated automatically in the pipeline at inference
+ self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
+
+ # backup text encoder embeddings
+ self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
+
+ try:
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
+ except EnvironmentError:
+ self.clip_image_processor = CLIPImageProcessor()
+ self.device = self.sd_ref().unet.device
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ self.config.image_encoder_path,
+ ignore_mismatched_sizes=True
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ if self.config.train_image_encoder:
+ self.image_encoder.train()
+ else:
+ self.image_encoder.eval()
+
+ # max_seq_len = CLIP tokens + CLS token
+ image_encoder_state_dict = self.image_encoder.state_dict()
+ in_tokens = 257
+ if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+ # clip
+ in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+ if hasattr(self.image_encoder.config, 'hidden_sizes'):
+ embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+ else:
+            embedding_dim = self.image_encoder.config.hidden_size
+
+ if self.config.clip_layer == 'image_embeds':
+ in_tokens = 1
+ embedding_dim = self.image_encoder.config.projection_dim
+
+ self.embedder = Embedder(
+ num_output_tokens=self.config.num_tokens,
+ num_input_tokens=in_tokens,
+ input_dim=embedding_dim,
+ output_dim=self.sd_ref().unet.config['cross_attention_dim'],
+ mid_dim=embedding_dim * self.config.num_tokens,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+
+ self.embedder.train()
+
+ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
+ state_dict = {
+ 'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+ }
+ if self.config.train_image_encoder:
+ state_dict['image_encoder'] = self.image_encoder.state_dict(
+ *args, destination=destination, prefix=prefix,
+ keep_vars=keep_vars)
+
+ return state_dict
+
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+ self.embedder.load_state_dict(state_dict["embedder"], strict=strict)
+ if self.config.train_image_encoder and 'image_encoder' in state_dict:
+ self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
+
+ def parameters(self, *args, **kwargs):
+ yield from self.embedder.parameters(*args, **kwargs)
+
+ def named_parameters(self, *args, **kwargs):
+ yield from self.embedder.named_parameters(*args, **kwargs)
+
+ def get_clip_image_embeds_from_tensors(
+ self, tensors_0_1: torch.Tensor, drop=False,
+ is_training=False,
+ has_been_preprocessed=False
+ ) -> torch.Tensor:
+ with torch.no_grad():
+ if not has_been_preprocessed:
+ # tensors should be 0-1
+ if tensors_0_1.ndim == 3:
+ tensors_0_1 = tensors_0_1.unsqueeze(0)
+ # training tensors are 0 - 1
+ tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+
+ # if images are out of this range throw error
+ if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+ raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+ tensors_0_1.min(), tensors_0_1.max()
+ ))
+ # unconditional
+ if drop:
+ if self.clip_noise_zero:
+ tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+ noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+ dtype=get_torch_dtype(self.sd_ref().dtype))
+ tensors_0_1 = tensors_0_1 * noise_scale
+ else:
+ tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+ # tensors_0_1 = tensors_0_1 * 0
+ clip_image = self.clip_image_processor(
+ images=tensors_0_1,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ ).pixel_values
+ else:
+ if drop:
+ # scale the noise down
+ if self.clip_noise_zero:
+ tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+ noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+ dtype=get_torch_dtype(self.sd_ref().dtype))
+ tensors_0_1 = tensors_0_1 * noise_scale
+ else:
+ tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+ # tensors_0_1 = tensors_0_1 * 0
+ mean = torch.tensor(self.clip_image_processor.image_mean).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ std = torch.tensor(self.clip_image_processor.image_std).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+ clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
+
+ else:
+ clip_image = tensors_0_1
+ clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+ with torch.set_grad_enabled(is_training):
+ if is_training:
+ self.image_encoder.train()
+ else:
+ self.image_encoder.eval()
+ clip_output = self.image_encoder(clip_image, output_hidden_states=True)
+
+ if self.config.clip_layer == 'penultimate_hidden_states':
+ # they skip last layer for ip+
+ # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
+ clip_image_embeds = clip_output.hidden_states[-2]
+ elif self.config.clip_layer == 'last_hidden_state':
+ clip_image_embeds = clip_output.hidden_states[-1]
+ else:
+ clip_image_embeds = clip_output.image_embeds
+ return clip_image_embeds
+
+
+ def set_vec(self, new_vector, text_encoder_idx=0):
+ # Get the embedding layer
+ embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings()
+
+ # Indices to replace in the embeddings
+ indices_to_replace = self.placeholder_token_ids[text_encoder_idx]
+
+ # Replace the specified embeddings with new_vector
+ for idx in indices_to_replace:
+ vector_idx = idx - indices_to_replace[0]
+ embedding_layer.weight[idx] = new_vector[vector_idx]
+
+ # adds it to the tokenizer
+ def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ if clip_image_embeds.ndim == 2:
+ # expand the token dimension
+ clip_image_embeds = clip_image_embeds.unsqueeze(1)
+ image_prompt_embeds = self.embedder(clip_image_embeds)
+ # todo add support for multiple batch sizes
+ if image_prompt_embeds.shape[0] != 1:
+ raise ValueError("Batch size must be 1 for embedder for now")
+
+ # output on sd1.5 is bs, num_tokens, 768
+ if len(self.text_encoder_list) == 1:
+ # add it to the text encoder
+ self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
+ elif len(self.text_encoder_list) == 2:
+            if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
+                    image_prompt_embeds.shape[2]:
+ raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
+ # sdxl variants
+ # image_prompt_embeds = 2048
+ # te1 = 768
+ # te2 = 1280
+            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
+            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
+ self.set_vec(te1_embeds[0], text_encoder_idx=0)
+ self.set_vec(te2_embeds[0], text_encoder_idx=1)
+        else:
+            raise ValueError("Unsupported number of text encoders")
+ # just a place to put a breakpoint
+ pass
+
+ def restore_embeddings(self):
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
+ self.text_encoder_list,
+ self.tokenizer_list,
+ self.orig_embeds_params,
+ self.placeholder_token_ids
+ ):
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+ index_no_updates[
+ min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
+ with torch.no_grad():
+ text_encoder.get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds[index_no_updates]
+ # detach it all
+ text_encoder.get_input_embeddings().weight.detach_()
+
+ def enable_gradient_checkpointing(self):
+ self.image_encoder.gradient_checkpointing = True
+
+ def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
+ output_prompt = prompt
+        embedding_tokens = self.embedding_tokens[0]  # should be the same for all tokenizers
+ default_replacements = ["[name]", "[trigger]"]
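+        # e.g. "photo of [trigger]" becomes "photo of <trigger>" (or the expanded multi-token
+        # form when expand_token=True); the trigger is prepended if it is absent from the prompt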
+
+ replace_with = embedding_tokens if expand_token else self.trigger
+ if to_replace_list is None:
+ to_replace_list = default_replacements
+ else:
+ to_replace_list += default_replacements
+
+ # remove duplicates
+ to_replace_list = list(set(to_replace_list))
+
+ # replace them all
+ for to_replace in to_replace_list:
+ # replace it
+ output_prompt = output_prompt.replace(to_replace, replace_with)
+
+ # see how many times replace_with is in the prompt
+ num_instances = output_prompt.count(replace_with)
+
+ if num_instances == 0 and add_if_not_present:
+ # add it to the beginning of the prompt
+ output_prompt = replace_with + " " + output_prompt
+
+ if num_instances > 1:
+ print(
+ f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+ return output_prompt
+
+ # reverses injection with class name. useful for normalizations
+ def inject_trigger_class_name_into_prompt(self, prompt):
+ output_prompt = prompt
+        embedding_tokens = self.embedding_tokens[0]  # should be the same for all tokenizers
+
+ default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger]
+
+ replace_with = self.config.trigger_class_name
+ to_replace_list = default_replacements
+
+ # remove duplicates
+ to_replace_list = list(set(to_replace_list))
+
+ # replace them all
+ for to_replace in to_replace_list:
+ # replace it
+ output_prompt = output_prompt.replace(to_replace, replace_with)
+
+ # see how many times replace_with is in the prompt
+ num_instances = output_prompt.count(replace_with)
+
+ if num_instances > 1:
+ print(
+ f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+ return output_prompt
diff --git a/toolkit/config.py b/toolkit/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..52de47b836540c319e3ca8faa479312619139769
--- /dev/null
+++ b/toolkit/config.py
@@ -0,0 +1,110 @@
+import os
+import json
+from typing import Union
+
+import oyaml as yaml
+import re
+from collections import OrderedDict
+
+from toolkit.paths import TOOLKIT_ROOT
+
+possible_extensions = ['.json', '.jsonc', '.yaml', '.yml']
+
+
+def get_cwd_abs_path(path):
+ if not os.path.isabs(path):
+ path = os.path.join(os.getcwd(), path)
+ return path
+
+
+def replace_env_vars_in_string(s: str) -> str:
+ """
+ Replace placeholders like ${VAR_NAME} with the value of the corresponding environment variable.
+ If the environment variable is not set, raise an error.
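+
+    e.g. "${HOME}/models" -> "/home/user/models" when HOME is /home/user (value illustrative).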
+ """
+
+ def replacer(match):
+ var_name = match.group(1)
+ value = os.environ.get(var_name)
+
+ if value is None:
+ raise ValueError(f"Environment variable {var_name} not set. Please ensure it's defined before proceeding.")
+
+ return value
+
+ return re.sub(r'\$\{([^}]+)\}', replacer, s)
+
+
+def preprocess_config(config: OrderedDict, name: str = None):
+ if "job" not in config:
+ raise ValueError("config file must have a job key")
+ if "config" not in config:
+ raise ValueError("config file must have a config section")
+ if "name" not in config["config"] and name is None:
+ raise ValueError("config file must have a config.name key")
+ # we need to replace tags. For now just [name]
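+    # e.g. any "[name]" placeholder in the config (such as in a folder path or repo id)
+    # is replaced with the job name before the config is used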
+ if name is None:
+ name = config["config"]["name"]
+ config_string = json.dumps(config)
+ config_string = config_string.replace("[name]", name)
+ config = json.loads(config_string, object_pairs_hook=OrderedDict)
+ return config
+
+
+# Fixes an issue where yaml doesn't load exponents correctly
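+# (PyYAML's default resolver requires a decimal point for scientific notation,
+# so a value like `lr: 1e-5` would otherwise be loaded as the string "1e-5")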
+fixed_loader = yaml.SafeLoader
+fixed_loader.add_implicit_resolver(
+ u'tag:yaml.org,2002:float',
+ re.compile(u'''^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$''', re.X),
+ list(u'-+0123456789.'))
+
+
+def get_config(
+ config_file_path_or_dict: Union[str, dict, OrderedDict],
+ name=None
+):
+ # if we got a dict, process it and return it
+ if isinstance(config_file_path_or_dict, dict) or isinstance(config_file_path_or_dict, OrderedDict):
+ config = config_file_path_or_dict
+ return preprocess_config(config, name)
+
+ config_file_path = config_file_path_or_dict
+
+ # first check if it is in the config folder
+ config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
+    # see if it is in the config folder with any of the possible extensions if it doesn't have one
+ real_config_path = None
+ if not os.path.exists(config_path):
+ for ext in possible_extensions:
+ if os.path.exists(config_path + ext):
+ real_config_path = config_path + ext
+ break
+
+ # if we didn't find it there, check if it is a full path
+ if not real_config_path:
+ if os.path.exists(config_file_path):
+ real_config_path = config_file_path
+ elif os.path.exists(get_cwd_abs_path(config_file_path)):
+ real_config_path = get_cwd_abs_path(config_file_path)
+
+ if not real_config_path:
+ raise ValueError(f"Could not find config file {config_file_path}")
+
+ # if we found it, check if it is a json or yaml file
+ with open(real_config_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ content_with_env_replaced = replace_env_vars_in_string(content)
+ if real_config_path.endswith('.json') or real_config_path.endswith('.jsonc'):
+ config = json.loads(content_with_env_replaced, object_pairs_hook=OrderedDict)
+ elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
+ config = yaml.load(content_with_env_replaced, Loader=fixed_loader)
+ else:
+ raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
+
+ return preprocess_config(config, name)
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e7215bf76e7fef5daee5c0000b6c1053608745d
--- /dev/null
+++ b/toolkit/config_modules.py
@@ -0,0 +1,927 @@
+import os
+import time
+from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
+import random
+
+import torch
+
+from toolkit.prompt_utils import PromptEmbeds
+
+ImgExt = Literal['jpg', 'png', 'webp']
+
+SaveFormat = Literal['safetensors', 'diffusers']
+
+if TYPE_CHECKING:
+ from toolkit.guidance import GuidanceType
+ from toolkit.logging import EmptyLogger
+else:
+ EmptyLogger = None
+
+class SaveConfig:
+ def __init__(self, **kwargs):
+ self.save_every: int = kwargs.get('save_every', 1000)
+ self.dtype: str = kwargs.get('dtype', 'float16')
+ self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
+ self.save_format: SaveFormat = kwargs.get('save_format', 'safetensors')
+ if self.save_format not in ['safetensors', 'diffusers']:
+ raise ValueError(f"save_format must be safetensors or diffusers, got {self.save_format}")
+ self.push_to_hub: bool = kwargs.get("push_to_hub", False)
+ self.hf_repo_id: Optional[str] = kwargs.get("hf_repo_id", None)
+        self.hf_private: bool = kwargs.get("hf_private", False)
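+
+    # Example (YAML, values illustrative) of the keys consumed above:
+    #   save:
+    #     save_every: 500
+    #     dtype: float16
+    #     max_step_saves_to_keep: 4
+    #     push_to_hub: true
+    #     hf_repo_id: your-username/your-lora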
+
+class LoggingConfig:
+ def __init__(self, **kwargs):
+ self.log_every: int = kwargs.get('log_every', 100)
+ self.verbose: bool = kwargs.get('verbose', False)
+ self.use_wandb: bool = kwargs.get('use_wandb', False)
+ self.project_name: str = kwargs.get('project_name', 'ai-toolkit')
+ self.run_name: str = kwargs.get('run_name', None)
+
+
+class SampleConfig:
+ def __init__(self, **kwargs):
+ self.sampler: str = kwargs.get('sampler', 'ddpm')
+ self.sample_every: int = kwargs.get('sample_every', 100)
+ self.width: int = kwargs.get('width', 512)
+ self.height: int = kwargs.get('height', 512)
+ self.prompts: list[str] = kwargs.get('prompts', [])
+ self.neg = kwargs.get('neg', False)
+ self.seed = kwargs.get('seed', 0)
+ self.walk_seed = kwargs.get('walk_seed', False)
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
+ self.sample_steps = kwargs.get('sample_steps', 20)
+ self.network_multiplier = kwargs.get('network_multiplier', 1)
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+ self.ext: ImgExt = kwargs.get('format', 'jpg')
+ self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
+ self.refiner_start_at = kwargs.get('refiner_start_at',
+ 0.5) # step to start using refiner on sample if it exists
+ self.extra_values = kwargs.get('extra_values', [])
+
+
+class LormModuleSettingsConfig:
+ def __init__(self, **kwargs):
+ self.contains: str = kwargs.get('contains', '4nt$3')
+ self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
+ # min num parameters to attach to
+ self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
+ self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
+
+
+class LoRMConfig:
+ def __init__(self, **kwargs):
+ self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
+ self.do_conv: bool = kwargs.get('do_conv', False)
+ self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
+ self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
+ module_settings = kwargs.get('module_settings', [])
+ default_module_settings = {
+ 'extract_mode': self.extract_mode,
+ 'extract_mode_param': self.extract_mode_param,
+ 'parameter_threshold': self.parameter_threshold,
+ }
+ module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
+ self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
+ module_setting in module_settings]
+
+ def get_config_for_module(self, block_name):
+ for setting in self.module_settings:
+ contain_pieces = setting.contains.split('|')
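+            # '|'-separated pieces must ALL be present in the block name for this setting to match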
+ if all(contain_piece in block_name for contain_piece in contain_pieces):
+ return setting
+ # try replacing the . with _
+ contain_pieces = setting.contains.replace('.', '_').split('|')
+ if all(contain_piece in block_name for contain_piece in contain_pieces):
+ return setting
+ # do default
+ return LormModuleSettingsConfig(**{
+ 'extract_mode': self.extract_mode,
+ 'extract_mode_param': self.extract_mode_param,
+ 'parameter_threshold': self.parameter_threshold,
+ })
+
+
+NetworkType = Literal['lora', 'locon', 'lorm']
+
+
+class NetworkConfig:
+ def __init__(self, **kwargs):
+ self.type: NetworkType = kwargs.get('type', 'lora')
+ rank = kwargs.get('rank', None)
+ linear = kwargs.get('linear', None)
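+        # 'rank' and 'linear' are aliases; whichever is provided sets both attributes below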
+ if rank is not None:
+ self.rank: int = rank # rank for backward compatibility
+ self.linear: int = rank
+ elif linear is not None:
+ self.rank: int = linear
+ self.linear: int = linear
+ self.conv: int = kwargs.get('conv', None)
+ self.alpha: float = kwargs.get('alpha', 1.0)
+ self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
+ self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
+ self.dropout: Union[float, None] = kwargs.get('dropout', None)
+ self.network_kwargs: dict = kwargs.get('network_kwargs', {})
+
+ self.lorm_config: Union[LoRMConfig, None] = None
+ lorm = kwargs.get('lorm', None)
+ if lorm is not None:
+ self.lorm_config: LoRMConfig = LoRMConfig(**lorm)
+
+ if self.type == 'lorm':
+ # set linear to arbitrary values so it makes them
+ self.linear = 4
+ self.rank = 4
+ if self.lorm_config.do_conv:
+ self.conv = 4
+
+ self.transformer_only = kwargs.get('transformer_only', True)
+
+
+AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
+
+CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
+
+
+class AdapterConfig:
+ def __init__(self, **kwargs):
+ self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net
+ self.in_channels: int = kwargs.get('in_channels', 3)
+ self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
+ self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
+ self.downscale_factor: int = kwargs.get('downscale_factor', 8)
+ self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
+ self.image_dir: str = kwargs.get('image_dir', None)
+ self.test_img_path: str = kwargs.get('test_img_path', None)
+        self.train: bool = kwargs.get('train', False)
+ self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
+ self.name_or_path = kwargs.get('name_or_path', None)
+
+ num_tokens = kwargs.get('num_tokens', None)
+ if num_tokens is None and self.type.startswith('ip'):
+ if self.type == 'ip+':
+                num_tokens = 16
+ elif self.type == 'ip':
+ num_tokens = 4
+
+ self.num_tokens: int = num_tokens
+ self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
+ self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
+ if self.train_only_image_encoder:
+ self.train_image_encoder = True
+ self.train_only_image_encoder_positional_embedding: bool = kwargs.get(
+ 'train_only_image_encoder_positional_embedding', False)
+ self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
+ self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
+ self.safe_channels: int = kwargs.get('safe_channels', 2048)
+ self.safe_tokens: int = kwargs.get('safe_tokens', 8)
+ self.quad_image: bool = kwargs.get('quad_image', False)
+
+ # clip vision
+ self.trigger = kwargs.get('trigger', 'tri993r')
+ self.trigger_class_name = kwargs.get('trigger_class_name', None)
+
+ self.class_names = kwargs.get('class_names', [])
+
+ self.clip_layer: CLIPLayer = kwargs.get('clip_layer', None)
+ if self.clip_layer is None:
+ if self.type.startswith('ip+'):
+ self.clip_layer = 'penultimate_hidden_states'
+ else:
+ self.clip_layer = 'last_hidden_state'
+
+ # text encoder
+ self.text_encoder_path: str = kwargs.get('text_encoder_path', None)
+ self.text_encoder_arch: str = kwargs.get('text_encoder_arch', 'clip') # clip t5
+
+ self.train_scaler: bool = kwargs.get('train_scaler', False)
+ self.scaler_lr: Optional[float] = kwargs.get('scaler_lr', None)
+
+ # trains with a scaler to ease channel bias but merges it in on save
+ self.merge_scaler: bool = kwargs.get('merge_scaler', False)
+
+ # for ilora
+ self.head_dim: int = kwargs.get('head_dim', 1024)
+ self.num_heads: int = kwargs.get('num_heads', 1)
+ self.ilora_down: bool = kwargs.get('ilora_down', True)
+ self.ilora_mid: bool = kwargs.get('ilora_mid', True)
+ self.ilora_up: bool = kwargs.get('ilora_up', True)
+
+ self.pixtral_max_image_size: int = kwargs.get('pixtral_max_image_size', 512)
+ self.pixtral_random_image_size: int = kwargs.get('pixtral_random_image_size', False)
+
+ self.flux_only_double: bool = kwargs.get('flux_only_double', False)
+
+ # train and use a conv layer to pool the embedding
+ self.conv_pooling: bool = kwargs.get('conv_pooling', False)
+ self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
+ self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
+
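+# Illustrative sketch of the defaults above: 'ip+' adapters get 16 tokens and the
+# penultimate hidden states, plain 'ip' gets 4 tokens and the last hidden state.
+# >>> a = AdapterConfig(type='ip+'); (a.num_tokens, a.clip_layer)
+# (16, 'penultimate_hidden_states')
+# >>> a = AdapterConfig(type='ip'); (a.num_tokens, a.clip_layer)
+# (4, 'last_hidden_state')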
+
+class EmbeddingConfig:
+ def __init__(self, **kwargs):
+ self.trigger = kwargs.get('trigger', 'custom_embedding')
+ self.tokens = kwargs.get('tokens', 4)
+ self.init_words = kwargs.get('init_words', '*')
+ self.save_format = kwargs.get('save_format', 'safetensors')
+ self.trigger_class_name = kwargs.get('trigger_class_name', None) # used for inverted masked prior
+
+
+class DecoratorConfig:
+ def __init__(self, **kwargs):
+ self.num_tokens: int = kwargs.get('num_tokens', 4)
+
+
+ContentOrStyleType = Literal['balanced', 'style', 'content']
+LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
+
+
+class TrainConfig:
+ def __init__(self, **kwargs):
+ self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
+ self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
+ self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style_reg', self.content_or_style)
+ self.steps: int = kwargs.get('steps', 1000)
+ self.lr = kwargs.get('lr', 1e-6)
+ self.unet_lr = kwargs.get('unet_lr', self.lr)
+ self.text_encoder_lr = kwargs.get('text_encoder_lr', self.lr)
+ self.refiner_lr = kwargs.get('refiner_lr', self.lr)
+ self.embedding_lr = kwargs.get('embedding_lr', self.lr)
+ self.adapter_lr = kwargs.get('adapter_lr', self.lr)
+ self.optimizer = kwargs.get('optimizer', 'adamw')
+ self.optimizer_params = kwargs.get('optimizer_params', {})
+ self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
+ self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
+ self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
+ self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
+ self.batch_size: int = kwargs.get('batch_size', 1)
+ self.orig_batch_size: int = self.batch_size
+ self.dtype: str = kwargs.get('dtype', 'fp32')
+ self.xformers = kwargs.get('xformers', False)
+ self.sdp = kwargs.get('sdp', False)
+ self.train_unet = kwargs.get('train_unet', True)
+ self.train_text_encoder = kwargs.get('train_text_encoder', False)
+ self.train_refiner = kwargs.get('train_refiner', True)
+ self.train_turbo = kwargs.get('train_turbo', False)
+ self.show_turbo_outputs = kwargs.get('show_turbo_outputs', False)
+ self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
+ self.snr_gamma = kwargs.get('snr_gamma', None)
+ # trains a gamma, offset, and scale to adjust loss to adapt to timestep differentials
+ # this should balance the learning rate across all timesteps over time
+ self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
+ self.noise_offset = kwargs.get('noise_offset', 0.0)
+ self.skip_first_sample = kwargs.get('skip_first_sample', False)
+ self.force_first_sample = kwargs.get('force_first_sample', False)
+ self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
+ self.weight_jitter = kwargs.get('weight_jitter', 0.0)
+ self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
+ self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
+ self.start_step = kwargs.get('start_step', None)
+ self.free_u = kwargs.get('free_u', False)
+ self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
+ self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
+ self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
+ self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
+ self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+ self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
+ self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
+ self.negative_prompt = kwargs.get('negative_prompt', None)
+ self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)
+ # multiplier applied to loss on regularization images
+ self.reg_weight = kwargs.get('reg_weight', 1.0)
+ self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
+ self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
+ # automatically adapt the vae scaling based on the image norm
+ self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False)
+
+ # dropout that happens before encoding. It functions independently per text encoder
+ self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
+
+ # match the norm of the noise before computing loss. This will help the model maintain its
+ # current understanding of the brightness of images.
+ self.match_noise_norm = kwargs.get('match_noise_norm', False)
+
+ # set to -1 to accumulate gradients for entire epoch
+ # warning, only do this with a small dataset or you will run out of memory
+ # This is legacy but left in for backwards compatibility
+ self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+
+ # this does proper gradient accumulation: a step is only shown at the end of each accumulation,
+ # whereas the legacy option above shows a step for every accumulation
+ self.gradient_accumulation = kwargs.get('gradient_accumulation', 1)
+ if self.gradient_accumulation > 1:
+ if self.gradient_accumulation_steps != 1:
+ raise ValueError("gradient_accumulation and gradient_accumulation_steps are mutually exclusive")
+
+ # short long captions will double your batch size. This only works when a dataset is
+ # prepared with a json caption file that has both short and long captions in it. It will
+ # double up every image and run it through with both short and long captions. The idea
+ # is that the network will learn how to generate good images with both short and long captions
+ self.short_and_long_captions = kwargs.get('short_and_long_captions', False)
+ # if above is NOT true, this will make it so the long caption goes to te2 and the short caption goes to te1 (sdxl only)
+ self.short_and_long_captions_encoder_split = kwargs.get('short_and_long_captions_encoder_split', False)
+
+ # basically gradient accumulation but we run just 1 item through the network
+ # and accumulate gradients. This can be used as basic gradient accumulation but is very helpful
+ # for training tricks that increase batch size but need a single gradient step
+ self.single_item_batching = kwargs.get('single_item_batching', False)
+
+ match_adapter_assist = kwargs.get('match_adapter_assist', False)
+ self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
+ self.loss_target: LossTarget = kwargs.get('loss_target',
+ 'noise') # noise, source, unaugmented, differential_noise
+
+ # When a mask is passed in a dataset, and this is true,
+ # we will predict noise without the LoRA network and use the prediction as a target for
+ # the unmasked region. It is basically unmasked regularization
+ self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
+ self.inverted_mask_prior_multiplier = kwargs.get('inverted_mask_prior_multiplier', 0.5)
+
+ # legacy
+ if match_adapter_assist and self.match_adapter_chance == 0.0:
+ self.match_adapter_chance = 1.0
+
+ # standardize inputs to the mean and std of the model's knowledge
+ self.standardize_images = kwargs.get('standardize_images', False)
+ self.standardize_latents = kwargs.get('standardize_latents', False)
+
+ if self.train_turbo and not self.noise_scheduler.startswith("euler"):
+ raise ValueError("train_turbo is only supported with the euler and euler_a noise schedulers")
+
+ self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
+ self.do_cfg = kwargs.get('do_cfg', False)
+ self.do_random_cfg = kwargs.get('do_random_cfg', False)
+ self.cfg_scale = kwargs.get('cfg_scale', 1.0)
+ self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
+ self.cfg_rescale = kwargs.get('cfg_rescale', None)
+ if self.cfg_rescale is None:
+ self.cfg_rescale = self.cfg_scale
+
+ # applies the inverse of the prediction mean and std to the target to correct
+ # for norm drift
+ self.correct_pred_norm = kwargs.get('correct_pred_norm', False)
+ self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0)
+
+ self.loss_type = kwargs.get('loss_type', 'mse')
+
+ # scale the prediction by this. Increase for more detail, decrease for less
+ self.pred_scaler = kwargs.get('pred_scaler', 1.0)
+
+ # repeats the prompt a few times to saturate the encoder
+ self.prompt_saturation_chance = kwargs.get('prompt_saturation_chance', 0.0)
+
+ # applies negative loss on the prior to encourage network to diverge from it
+ self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
+
+ ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
+ if ema_config is not None:
+ ema_config['use_ema'] = True
+ print(f"Using EMA")
+ else:
+ ema_config = {'use_ema': False}
+
+ self.ema_config: EMAConfig = EMAConfig(**ema_config)
+
+ # adds an additional loss to the network to encourage it output a normalized standard deviation
+ self.target_norm_std = kwargs.get('target_norm_std', None)
+ self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
+ self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear
+ self.linear_timesteps = kwargs.get('linear_timesteps', False)
+ self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
+ self.disable_sampling = kwargs.get('disable_sampling', False)
+
+ # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
+ # will make training faster and use less vram
+ self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
+ # for swapping which parameters are trained during training
+ self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
+ # 0.1 means 10% of the parameters are active at a time; lower uses less vram, higher uses more
+ self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
+ # bypass the guidance embedding during training, for open flux models that have a guidance embedding
+ self.bypass_guidance_embedding = kwargs.get('bypass_guidance_embedding', False)
+
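+# Illustrative sketch of the accumulation options above: the legacy
+# gradient_accumulation_steps and the newer gradient_accumulation are mutually
+# exclusive, so only one of them may be raised above 1.
+# >>> TrainConfig(gradient_accumulation=4)                                   # ok
+# >>> TrainConfig(gradient_accumulation=4, gradient_accumulation_steps=4)    # raises ValueError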
+
+class ModelConfig:
+ def __init__(self, **kwargs):
+ self.name_or_path: str = kwargs.get('name_or_path', None)
+ # name or path is updated on fine tuning. Keep a copy of the original
+ self.name_or_path_original: str = self.name_or_path
+ self.is_v2: bool = kwargs.get('is_v2', False)
+ self.is_xl: bool = kwargs.get('is_xl', False)
+ self.is_pixart: bool = kwargs.get('is_pixart', False)
+ self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
+ self.is_auraflow: bool = kwargs.get('is_auraflow', False)
+ self.is_v3: bool = kwargs.get('is_v3', False)
+ self.is_flux: bool = kwargs.get('is_flux', False)
+ if self.is_pixart_sigma:
+ self.is_pixart = True
+ self.use_flux_cfg = kwargs.get('use_flux_cfg', False)
+ self.is_ssd: bool = kwargs.get('is_ssd', False)
+ self.is_vega: bool = kwargs.get('is_vega', False)
+ self.is_v_pred: bool = kwargs.get('is_v_pred', False)
+ self.dtype: str = kwargs.get('dtype', 'float16')
+ self.vae_path = kwargs.get('vae_path', None)
+ self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
+ self._original_refiner_name_or_path = self.refiner_name_or_path
+ self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
+ self.lora_path = kwargs.get('lora_path', None)
+ # mainly for decompression loras for distilled models
+ self.assistant_lora_path = kwargs.get('assistant_lora_path', None)
+ self.inference_lora_path = kwargs.get('inference_lora_path', None)
+ self.latent_space_version = kwargs.get('latent_space_version', None)
+
+ # only for SDXL models for now
+ self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
+ self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True)
+
+ self.experimental_xl: bool = kwargs.get('experimental_xl', False)
+
+ if self.name_or_path is None:
+ raise ValueError('name_or_path must be specified')
+
+ if self.is_ssd:
+ # set is_xl to true since it is mostly the same architecture
+ self.is_xl = True
+
+ if self.is_vega:
+ self.is_xl = True
+
+ # for text encoder quant. Only works with pixart currently
+ self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4
+ self.unet_path = kwargs.get("unet_path", None)
+ self.unet_sample_size = kwargs.get("unet_sample_size", None)
+ self.vae_device = kwargs.get("vae_device", None)
+ self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
+ self.te_device = kwargs.get("te_device", None)
+ self.te_dtype = kwargs.get("te_dtype", self.dtype)
+
+ # only for flux for now
+ self.quantize = kwargs.get("quantize", False)
+ self.low_vram = kwargs.get("low_vram", False)
+ self.attn_masking = kwargs.get("attn_masking", False)
+ if self.attn_masking and not self.is_flux:
+ raise ValueError("attn_masking is only supported with flux models currently")
+ # for targeting specific layers
+ self.ignore_if_contains: Optional[List[str]] = kwargs.get("ignore_if_contains", None)
+ self.only_if_contains: Optional[List[str]] = kwargs.get("only_if_contains", None)
+ self.quantize_kwargs = kwargs.get("quantize_kwargs", {})
+
+ if self.ignore_if_contains is not None or self.only_if_contains is not None:
+ if not self.is_flux:
+ raise ValueError("ignore_if_contains and only_if_contains are only supported with flux models currently")
+
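+# Illustrative sketch of the validation above: name_or_path is required, and the
+# SSD / Vega flags route through the SDXL code path (model path is a placeholder).
+# >>> ModelConfig(name_or_path='path/to/ssd-model', is_ssd=True).is_xl
+# True
+# >>> ModelConfig()    # raises ValueError('name_or_path must be specified')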
+
+class EMAConfig:
+ def __init__(self, **kwargs):
+ self.use_ema: bool = kwargs.get('use_ema', False)
+ self.ema_decay: float = kwargs.get('ema_decay', 0.999)
+ # feeds back the decay difference into the parameter
+ self.use_feedback: bool = kwargs.get('use_feedback', False)
+
+ # every update, the params are multiplied by this amount
+ # only use for things without a bias like lora
+ # similar to a decay in an optimizer but the opposite
+ self.param_multiplier: float = kwargs.get('param_multiplier', 1.0)
+
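+# The EMA update itself lives elsewhere in the toolkit; as a rough, assumed sketch
+# of what these fields parameterize (not the exact implementation):
+# >>> ema_param = ema_decay * ema_param + (1.0 - ema_decay) * param   # standard EMA step
+# >>> param = param * param_multiplier                                # optional per-update multiplier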
+
+class ReferenceDatasetConfig:
+ def __init__(self, **kwargs):
+ # can pass a side by side pair or a folder with pos and neg subfolders
+ self.pair_folder: str = kwargs.get('pair_folder', None)
+ self.pos_folder: str = kwargs.get('pos_folder', None)
+ self.neg_folder: str = kwargs.get('neg_folder', None)
+
+ self.network_weight: float = float(kwargs.get('network_weight', 1.0))
+ self.pos_weight: float = float(kwargs.get('pos_weight', self.network_weight))
+ self.neg_weight: float = float(kwargs.get('neg_weight', self.network_weight))
+ # make sure they are all absolute values no negatives
+ self.pos_weight = abs(self.pos_weight)
+ self.neg_weight = abs(self.neg_weight)
+
+ self.target_class: str = kwargs.get('target_class', '')
+ self.size: int = kwargs.get('size', 512)
+
+
+class SliderTargetConfig:
+ def __init__(self, **kwargs):
+ self.target_class: str = kwargs.get('target_class', '')
+ self.positive: str = kwargs.get('positive', '')
+ self.negative: str = kwargs.get('negative', '')
+ self.multiplier: float = kwargs.get('multiplier', 1.0)
+ self.weight: float = kwargs.get('weight', 1.0)
+ self.shuffle: bool = kwargs.get('shuffle', False)
+
+
+class GuidanceConfig:
+ def __init__(self, **kwargs):
+ self.target_class: str = kwargs.get('target_class', '')
+ self.guidance_scale: float = kwargs.get('guidance_scale', 1.0)
+ self.positive_prompt: str = kwargs.get('positive_prompt', '')
+ self.negative_prompt: str = kwargs.get('negative_prompt', '')
+
+
+class SliderConfigAnchors:
+ def __init__(self, **kwargs):
+ self.prompt = kwargs.get('prompt', '')
+ self.neg_prompt = kwargs.get('neg_prompt', '')
+ self.multiplier = kwargs.get('multiplier', 1.0)
+
+
+class SliderConfig:
+ def __init__(self, **kwargs):
+ targets = kwargs.get('targets', [])
+ anchors = kwargs.get('anchors', [])
+ anchors = [SliderConfigAnchors(**anchor) for anchor in anchors]
+ self.anchors: List[SliderConfigAnchors] = anchors
+ self.resolutions: List[List[int]] = kwargs.get('resolutions', [[512, 512]])
+ self.prompt_file: str = kwargs.get('prompt_file', None)
+ self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
+ self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
+ self.use_adapter: bool = kwargs.get('use_adapter', None) # depth
+ self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
+ self.low_ram = kwargs.get('low_ram', False)
+
+ # expand targets if shuffling
+ from toolkit.prompt_utils import get_slider_target_permutations
+ self.targets: List[SliderTargetConfig] = []
+ targets = [SliderTargetConfig(**target) for target in targets]
+ # do permutations if shuffle is true
+ print(f"Building slider targets")
+ for target in targets:
+ if target.shuffle:
+ target_permutations = get_slider_target_permutations(target, max_permutations=8)
+ self.targets = self.targets + target_permutations
+ else:
+ self.targets.append(target)
+ print(f"Built {len(self.targets)} slider targets (with permutations)")
+
+
+class DatasetConfig:
+ """
+ Dataset config for sd-datasets
+
+ """
+
+ def __init__(self, **kwargs):
+ self.type = kwargs.get('type', 'image') # sd, slider, reference
+ # will be legacy
+ self.folder_path: str = kwargs.get('folder_path', None)
+ # can be json or folder path
+ self.dataset_path: str = kwargs.get('dataset_path', None)
+
+ self.default_caption: str = kwargs.get('default_caption', None)
+ # trigger word for just this dataset
+ self.trigger_word: str = kwargs.get('trigger_word', None)
+ random_triggers = kwargs.get('random_triggers', [])
+ # if they are a string, load them from a file
+ if isinstance(random_triggers, str) and os.path.exists(random_triggers):
+ with open(random_triggers, 'r') as f:
+ random_triggers = f.read().splitlines()
+ # remove empty lines
+ random_triggers = [line for line in random_triggers if line.strip() != '']
+ self.random_triggers: List[str] = random_triggers
+ self.random_triggers_max: int = kwargs.get('random_triggers_max', 1)
+ self.caption_ext: str = kwargs.get('caption_ext', None)
+ self.random_scale: bool = kwargs.get('random_scale', False)
+ self.random_crop: bool = kwargs.get('random_crop', False)
+ self.resolution: int = kwargs.get('resolution', 512)
+ self.scale: float = kwargs.get('scale', 1.0)
+ self.buckets: bool = kwargs.get('buckets', True)
+ self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
+ self.is_reg: bool = kwargs.get('is_reg', False)
+ self.network_weight: float = float(kwargs.get('network_weight', 1.0))
+ self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
+ self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
+ self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
+ self.keep_tokens: int = kwargs.get('keep_tokens', 0) # number of first tokens to always keep unless the caption is dropped
+ self.flip_x: bool = kwargs.get('flip_x', False)
+ self.flip_y: bool = kwargs.get('flip_y', False)
+ self.augments: List[str] = kwargs.get('augments', [])
+ self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
+ # instead of cropping to match the image, it will serve the full size control image (e.g. clip images for ip adapters)
+ self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
+ self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
+ self.mask_path: str = kwargs.get('mask_path',
+ None) # focus mask (black and white. White has higher loss than black)
+ self.unconditional_path: str = kwargs.get('unconditional_path',
+ None) # path where matching unconditional images are located
+ self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
+ self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for the mask, 0 - 1
+ self.poi: Union[str, None] = kwargs.get('poi',
+ None) # if one is set and present in the json data, it will be used as the auto crop/scale point of interest
+ self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset
+ # cache latents will store them in memory
+ self.cache_latents: bool = kwargs.get('cache_latents', False)
+ # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
+ self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
+ self.cache_clip_vision_to_disk: bool = kwargs.get('cache_clip_vision_to_disk', False)
+
+ self.standardize_images: bool = kwargs.get('standardize_images', False)
+
+ # https://albumentations.ai/docs/api_reference/augmentations/transforms
+ # augmentations are returned as a separate image and cannot currently be cached
+ self.augmentations: List[dict] = kwargs.get('augmentations', None)
+ self.shuffle_augmentations: bool = kwargs.get('shuffle_augmentations', False)
+
+ has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
+
+ if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
+ print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
+ self.cache_latents = False
+ self.cache_latents_to_disk = False
+
+ # legacy compatibility
+ legacy_caption_type = kwargs.get('caption_type', None)
+ if legacy_caption_type:
+ self.caption_ext = legacy_caption_type
+ self.caption_type = self.caption_ext
+ self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted')
+
+ # ip adapter / reference dataset
+ self.clip_image_path: str = kwargs.get('clip_image_path', None) # reference / clip vision images (e.g. for ip adapters)
+ # get the clip image randomly from the same folder as the image. Useful for folder grouped pairs.
+ self.clip_image_from_same_folder: bool = kwargs.get('clip_image_from_same_folder', False)
+ self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
+ self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
+ self.replacements: List[str] = kwargs.get('replacements', [])
+ self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
+
+ self.num_workers: int = kwargs.get('num_workers', 2)
+ self.prefetch_factor: int = kwargs.get('prefetch_factor', 2)
+ self.extra_values: List[float] = kwargs.get('extra_values', [])
+ self.square_crop: bool = kwargs.get('square_crop', False)
+ # apply same augmentations to control images. Usually want this true unless special case
+ self.replay_transforms: bool = kwargs.get('replay_transforms', True)
+
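+# Illustrative sketch of the caching guard above: augmentations change the pixels
+# every epoch, so they disable latent caching (augmentation dict format assumed).
+# >>> ds = DatasetConfig(folder_path='/data/set', caption_ext='txt',
+# ... cache_latents_to_disk=True, augmentations=[{'ColorJitter': {'p': 0.5}}])
+# >>> ds.cache_latents_to_disk
+# False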
+
+def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
+ """
+ This just splits up the datasets by resolution so you don't have to do it manually
+ :param raw_config:
+ :return:
+ """
+ # split up datasets by resolutions
+ new_config = []
+ for dataset in raw_config:
+ resolution = dataset.get('resolution', 512)
+ if isinstance(resolution, list):
+ resolution_list = resolution
+ else:
+ resolution_list = [resolution]
+ for res in resolution_list:
+ dataset_copy = dataset.copy()
+ dataset_copy['resolution'] = res
+ new_config.append(dataset_copy)
+ return new_config
+
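+# Illustrative sketch: an entry with a list of resolutions is expanded into one
+# entry per resolution, with everything else copied as-is.
+# >>> preprocess_dataset_raw_config([{'folder_path': '/data/set', 'resolution': [512, 768]}])
+# [{'folder_path': '/data/set', 'resolution': 512}, {'folder_path': '/data/set', 'resolution': 768}]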
+
+class GenerateImageConfig:
+ def __init__(
+ self,
+ prompt: str = '',
+ prompt_2: Optional[str] = None,
+ width: int = 512,
+ height: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: str = '',
+ negative_prompt_2: Optional[str] = None,
+ seed: int = -1,
+ network_multiplier: float = 1.0,
+ guidance_rescale: float = 0.0,
+ # the tag [time] will be replaced with milliseconds since epoch
+ output_path: str = None, # full image path
+ output_folder: str = None, # folder to save image in if output_path is not specified
+ output_ext: str = ImgExt, # extension to save image as if output_path is not specified
+ output_tail: str = '', # tail to add to output filename
+ add_prompt_file: bool = False, # add a prompt file with generated image
+ adapter_image_path: str = None, # path to adapter image
+ adapter_conditioning_scale: float = 1.0, # scale for adapter conditioning
+ latents: Union[torch.Tensor, None] = None, # input latent to start with
+ extra_kwargs: dict = None, # extra data to save with prompt file
+ refiner_start_at: float = 0.5, # start the refiner at this fraction of the steps, 0.0 to 1.0 (1.0 is the end)
+ extra_values: List[float] = None, # extra values to save with prompt file
+ logger: Optional[EmptyLogger] = None,
+ ):
+ self.width: int = width
+ self.height: int = height
+ self.num_inference_steps: int = num_inference_steps
+ self.guidance_scale: float = guidance_scale
+ self.guidance_rescale: float = guidance_rescale
+ self.prompt: str = prompt
+ self.prompt_2: str = prompt_2
+ self.negative_prompt: str = negative_prompt
+ self.negative_prompt_2: str = negative_prompt_2
+ self.latents: Union[torch.Tensor, None] = latents
+
+ self.output_path: str = output_path
+ self.seed: int = seed
+ if self.seed == -1:
+ # generate random one
+ self.seed = random.randint(0, 2 ** 32 - 1)
+ self.network_multiplier: float = network_multiplier
+ self.output_folder: str = output_folder
+ self.output_ext: str = output_ext
+ self.add_prompt_file: bool = add_prompt_file
+ self.output_tail: str = output_tail
+ self.gen_time: int = int(time.time() * 1000)
+ self.adapter_image_path: str = adapter_image_path
+ self.adapter_conditioning_scale: float = adapter_conditioning_scale
+ self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
+ self.refiner_start_at = refiner_start_at
+ self.extra_values = extra_values if extra_values is not None else []
+
+ # prompt string will override any settings above
+ self._process_prompt_string()
+
+ # handle dual text encoder prompts if nothing passed
+ if negative_prompt_2 is None:
+ self.negative_prompt_2 = negative_prompt
+
+ if prompt_2 is None:
+ self.prompt_2 = self.prompt
+
+ # parse prompt paths
+ if self.output_path is None and self.output_folder is None:
+ raise ValueError('output_path or output_folder must be specified')
+ elif self.output_path is not None:
+ self.output_folder = os.path.dirname(self.output_path)
+ self.output_ext = os.path.splitext(self.output_path)[1][1:]
+ self.output_filename_no_ext = os.path.splitext(os.path.basename(self.output_path))[0]
+
+ else:
+ self.output_filename_no_ext = '[time]_[count]'
+ if len(self.output_tail) > 0:
+ self.output_filename_no_ext += '_' + self.output_tail
+ self.output_path = os.path.join(self.output_folder, self.output_filename_no_ext + '.' + self.output_ext)
+
+ # adjust height and width to be divisible by 8
+ self.height = max(64, self.height - self.height % 8) # round to divisible by 8
+ self.width = max(64, self.width - self.width % 8) # round to divisible by 8
+
+ self.logger = logger
+
+ def set_gen_time(self, gen_time: int = None):
+ if gen_time is not None:
+ self.gen_time = gen_time
+ else:
+ self.gen_time = int(time.time() * 1000)
+
+ def _get_path_no_ext(self, count: int = 0, max_count=0):
+ # zero pad count
+ count_str = str(count).zfill(len(str(max_count)))
+ # replace [time] with gen time
+ filename = self.output_filename_no_ext.replace('[time]', str(self.gen_time))
+ # replace [count] with count
+ filename = filename.replace('[count]', count_str)
+ return filename
+
+ def get_image_path(self, count: int = 0, max_count=0):
+ filename = self._get_path_no_ext(count, max_count)
+ ext = self.output_ext
+ # if it does not start with a dot add one
+ if ext[0] != '.':
+ ext = '.' + ext
+ filename += ext
+ # join with folder
+ return os.path.join(self.output_folder, filename)
+
+ def get_prompt_path(self, count: int = 0, max_count=0):
+ filename = self._get_path_no_ext(count, max_count)
+ filename += '.txt'
+ # join with folder
+ return os.path.join(self.output_folder, filename)
+
+ def save_image(self, image, count: int = 0, max_count=0):
+ # make parent dirs
+ os.makedirs(self.output_folder, exist_ok=True)
+ self.set_gen_time()
+ # TODO save image gen header info for A1111 and us, our seeds probably wont match
+ image.save(self.get_image_path(count, max_count))
+ # do prompt file
+ if self.add_prompt_file:
+ self.save_prompt_file(count, max_count)
+
+ def save_prompt_file(self, count: int = 0, max_count=0):
+ # save prompt file
+ with open(self.get_prompt_path(count, max_count), 'w') as f:
+ prompt = self.prompt
+ if self.prompt_2 is not None:
+ prompt += ' --p2 ' + self.prompt_2
+ if self.negative_prompt is not None:
+ prompt += ' --n ' + self.negative_prompt
+ if self.negative_prompt_2 is not None:
+ prompt += ' --n2 ' + self.negative_prompt_2
+ prompt += ' --w ' + str(self.width)
+ prompt += ' --h ' + str(self.height)
+ prompt += ' --seed ' + str(self.seed)
+ prompt += ' --cfg ' + str(self.guidance_scale)
+ prompt += ' --steps ' + str(self.num_inference_steps)
+ prompt += ' --m ' + str(self.network_multiplier)
+ prompt += ' --gr ' + str(self.guidance_rescale)
+
+ # write the full prompt string, including the generation settings flags
+ f.write(prompt)
+
+ def _process_prompt_string(self):
+ # we will try to support all sd-scripts where we can
+
+ # FROM SD-SCRIPTS
+ # --n Treat everything until the next option as a negative prompt.
+ # --w Specify the width of the generated image.
+ # --h Specify the height of the generated image.
+ # --d Specify the seed for the generated image.
+ # --l Specify the CFG scale for the generated image.
+ # --s Specify the number of steps during generation.
+
+ # OURS and some QOL additions
+ # --m Specify the network multiplier for the generated image.
+ # --p2 Prompt for the second text encoder (SDXL only)
+ # --n2 Negative prompt for the second text encoder (SDXL only)
+ # --gr Specify the guidance rescale for the generated image (SDXL only)
+
+ # --seed Specify the seed for the generated image same as --d
+ # --cfg Specify the CFG scale for the generated image same as --l
+ # --steps Specify the number of steps during generation same as --s
+ # --network_multiplier Specify the network multiplier for the generated image same as --m
+
+ # process prompt string and update values if it has some
+ if self.prompt is not None and len(self.prompt) > 0:
+ # process prompt string
+ prompt = self.prompt
+ prompt = prompt.strip()
+ p_split = prompt.split('--')
+ self.prompt = p_split[0].strip()
+
+ if len(p_split) > 1:
+ for split in p_split[1:]:
+ # allows multi char flags
+ flag = split.split(' ')[0].strip()
+ content = split[len(flag):].strip()
+ if flag == 'p2':
+ self.prompt_2 = content
+ elif flag == 'n':
+ self.negative_prompt = content
+ elif flag == 'n2':
+ self.negative_prompt_2 = content
+ elif flag == 'w':
+ self.width = int(content)
+ elif flag == 'h':
+ self.height = int(content)
+ elif flag == 'd':
+ self.seed = int(content)
+ elif flag == 'seed':
+ self.seed = int(content)
+ elif flag == 'l':
+ self.guidance_scale = float(content)
+ elif flag == 'cfg':
+ self.guidance_scale = float(content)
+ elif flag == 's':
+ self.num_inference_steps = int(content)
+ elif flag == 'steps':
+ self.num_inference_steps = int(content)
+ elif flag == 'm':
+ self.network_multiplier = float(content)
+ elif flag == 'network_multiplier':
+ self.network_multiplier = float(content)
+ elif flag == 'gr':
+ self.guidance_rescale = float(content)
+ elif flag == 'a':
+ self.adapter_conditioning_scale = float(content)
+ elif flag == 'ref':
+ self.refiner_start_at = float(content)
+ elif flag == 'ev':
+ # split by comma
+ self.extra_values = [float(val) for val in content.split(',')]
+ elif flag == 'extra_values':
+ # split by comma
+ self.extra_values = [float(val) for val in content.split(',')]
+
+ def post_process_embeddings(
+ self,
+ conditional_prompt_embeds: PromptEmbeds,
+ unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
+ ):
+ # this is called after prompt embeds are encoded. We can override them in the future here
+ pass
+
+ def log_image(self, image, count: int = 0, max_count=0):
+ if self.logger is None:
+ return
+
+ self.logger.log_image(image, count, self.prompt)
+
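+# Illustrative sketch of the prompt-flag parsing and path handling above: flags
+# after -- override the constructor arguments, and with only output_folder set
+# the filename defaults to the [time]_[count] pattern.
+# >>> cfg = GenerateImageConfig(prompt='a photo of a cat --n blurry --w 768 --h 512 --seed 42 --cfg 4',
+# ... output_folder='/tmp/samples')
+# >>> (cfg.prompt, cfg.negative_prompt, cfg.width, cfg.height, cfg.seed, cfg.guidance_scale)
+# ('a photo of a cat', 'blurry', 768, 512, 42, 4.0)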
+
+def validate_configs(
+ train_config: TrainConfig,
+ model_config: ModelConfig,
+ save_config: SaveConfig,
+):
+ if model_config.is_flux:
+ if save_config.save_format != 'diffusers':
+ # make it diffusers
+ save_config.save_format = 'diffusers'
+ if model_config.use_flux_cfg:
+ # bypass the embedding
+ train_config.bypass_guidance_embedding = True
diff --git a/toolkit/cuda_malloc.py b/toolkit/cuda_malloc.py
new file mode 100644
index 0000000000000000000000000000000000000000..239b9666a83ea3f3838737b725902c6590ea19bc
--- /dev/null
+++ b/toolkit/cuda_malloc.py
@@ -0,0 +1,93 @@
+# ref comfy ui
+import os
+import importlib.util
+
+
+# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
+def get_gpu_names():
+ if os.name == 'nt':
+ import ctypes
+
+ # Define necessary C structures and types
+ class DISPLAY_DEVICEA(ctypes.Structure):
+ _fields_ = [
+ ('cb', ctypes.c_ulong),
+ ('DeviceName', ctypes.c_char * 32),
+ ('DeviceString', ctypes.c_char * 128),
+ ('StateFlags', ctypes.c_ulong),
+ ('DeviceID', ctypes.c_char * 128),
+ ('DeviceKey', ctypes.c_char * 128)
+ ]
+
+ # Load user32.dll
+ user32 = ctypes.windll.user32
+
+ # Call EnumDisplayDevicesA
+ def enum_display_devices():
+ device_info = DISPLAY_DEVICEA()
+ device_info.cb = ctypes.sizeof(device_info)
+ device_index = 0
+ gpu_names = set()
+
+ while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0):
+ device_index += 1
+ gpu_names.add(device_info.DeviceString.decode('utf-8'))
+ return gpu_names
+
+ return enum_display_devices()
+ else:
+ return set()
+
+
+blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950",
+ "GeForce 945M",
+ "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745",
+ "Quadro K620",
+ "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
+ "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000",
+ "Quadro M5500", "Quadro M6000",
+ "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
+ "GeForce GTX 1650", "GeForce GTX 1630"
+ }
+
+
+def cuda_malloc_supported():
+ try:
+ names = get_gpu_names()
+ except:
+ names = set()
+ for x in names:
+ if "NVIDIA" in x:
+ for b in blacklist:
+ if b in x:
+ return False
+ return True
+
+
+cuda_malloc = False
+
+if not cuda_malloc:
+ try:
+ version = ""
+ torch_spec = importlib.util.find_spec("torch")
+ for folder in torch_spec.submodule_search_locations:
+ ver_file = os.path.join(folder, "version.py")
+ if os.path.isfile(ver_file):
+ spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ version = module.__version__
+ if int(version[0]) >= 2: # enable by default for torch version 2.0 and up
+ cuda_malloc = cuda_malloc_supported()
+ except:
+ pass
+
+if cuda_malloc:
+ env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
+ if env_var is None:
+ env_var = "backend:cudaMallocAsync"
+ else:
+ env_var += ",backend:cudaMallocAsync"
+
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
+ print("CUDA Malloc Async Enabled")
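+# Illustrative example of the resulting setting (values assumed): if
+# PYTORCH_CUDA_ALLOC_CONF was unset it becomes "backend:cudaMallocAsync";
+# if it was already "max_split_size_mb:512" it becomes
+# "max_split_size_mb:512,backend:cudaMallocAsync".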
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..12a4df4be4aa406a5bb08bbd0ae5495363490da7
--- /dev/null
+++ b/toolkit/custom_adapter.py
@@ -0,0 +1,1026 @@
+import math
+import torch
+import sys
+
+from PIL import Image
+from torch.nn import Parameter
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, CLIPTextModel, \
+ CLIPTokenizer, T5Tokenizer
+
+from toolkit.models.clip_fusion import CLIPFusionModule
+from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
+from toolkit.models.ilora import InstantLoRAModule
+from toolkit.models.single_value_adapter import SingleValueAdapter
+from toolkit.models.te_adapter import TEAdapter
+from toolkit.models.te_aug_adapter import TEAugAdapter
+from toolkit.models.vd_adapter import VisionDirectAdapter
+from toolkit.models.redux import ReduxImageEncoder
+from toolkit.paths import REPOS_ROOT
+from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
+from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
+from toolkit.train_tools import get_torch_dtype
+from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
+import random
+
+sys.path.append(REPOS_ROOT)
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
+from collections import OrderedDict
+from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
+ AttnProcessor2_0
+from ipadapter.ip_adapter.ip_adapter import ImageProjModel
+from ipadapter.ip_adapter.resampler import Resampler
+from toolkit.config_modules import AdapterConfig, AdapterTypes
+from toolkit.prompt_utils import PromptEmbeds
+import weakref
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection,
+ CLIPVisionModel,
+ AutoImageProcessor,
+ ConvNextModel,
+ ConvNextForImageClassification,
+ ConvNextImageProcessor,
+ UMT5EncoderModel, LlamaTokenizerFast
+)
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
+
+from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
+
+from transformers import ViTFeatureExtractor, ViTForImageClassification
+
+import torch.nn.functional as F
+
+
+class CustomAdapter(torch.nn.Module):
+ def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
+ super().__init__()
+ self.config = adapter_config
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ self.device = self.sd_ref().unet.device
+ self.image_processor: CLIPImageProcessor = None
+ self.input_size = 224
+ self.adapter_type: AdapterTypes = self.config.type
+ self.current_scale = 1.0
+ self.is_active = True
+ self.flag_word = "fla9wor0"
+ self.is_unconditional_run = False
+
+ self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
+
+ self.fuse_module: FuseModule = None
+
+ self.lora: None = None
+
+ self.position_ids: Optional[List[int]] = None
+
+ self.num_control_images = 1
+ self.token_mask: Optional[torch.Tensor] = None
+
+ # setup clip
+ self.setup_clip()
+ # add for dataloader
+ self.clip_image_processor = self.image_processor
+
+ self.clip_fusion_module: CLIPFusionModule = None
+ self.ilora_module: InstantLoRAModule = None
+
+ self.te: Union[T5EncoderModel, CLIPTextModel] = None
+ self.tokenizer: CLIPTokenizer = None
+ self.te_adapter: TEAdapter = None
+ self.te_augmenter: TEAugAdapter = None
+ self.vd_adapter: VisionDirectAdapter = None
+ self.single_value_adapter: SingleValueAdapter = None
+ self.redux_adapter: ReduxImageEncoder = None
+
+ self.conditional_embeds: Optional[torch.Tensor] = None
+ self.unconditional_embeds: Optional[torch.Tensor] = None
+
+ self.setup_adapter()
+
+ if self.adapter_type == 'photo_maker':
+ # try to load from our name_or_path
+ if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
+ self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
+ # add the trigger word to the tokenizer
+ if isinstance(self.sd_ref().tokenizer, list):
+ for tokenizer in self.sd_ref().tokenizer:
+ tokenizer.add_tokens([self.flag_word], special_tokens=True)
+ else:
+ self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
+ elif self.config.name_or_path is not None:
+ loaded_state_dict = load_custom_adapter_model(
+ self.config.name_or_path,
+ self.sd_ref().device,
+ dtype=self.sd_ref().dtype,
+ )
+ self.load_state_dict(loaded_state_dict, strict=False)
+
+ def setup_adapter(self):
+ torch_dtype = get_torch_dtype(self.sd_ref().dtype)
+ if self.adapter_type == 'photo_maker':
+ sd = self.sd_ref()
+ embed_dim = sd.unet.config['cross_attention_dim']
+ self.fuse_module = FuseModule(embed_dim)
+ elif self.adapter_type == 'clip_fusion':
+ sd = self.sd_ref()
+ embed_dim = sd.unet.config['cross_attention_dim']
+
+ vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
+ if self.config.image_encoder_arch == 'clip':
+ vision_tokens = vision_tokens + 1
+ self.clip_fusion_module = CLIPFusionModule(
+ text_hidden_size=embed_dim,
+ text_tokens=77,
+ vision_hidden_size=self.vision_encoder.config.hidden_size,
+ vision_tokens=vision_tokens
+ )
+ elif self.adapter_type == 'ilora':
+ vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
+ if self.config.image_encoder_arch == 'clip':
+ vision_tokens = vision_tokens + 1
+
+ vision_hidden_size = self.vision_encoder.config.hidden_size
+
+ if self.config.clip_layer == 'image_embeds':
+ vision_tokens = 1
+ vision_hidden_size = self.vision_encoder.config.projection_dim
+
+ self.ilora_module = InstantLoRAModule(
+ vision_tokens=vision_tokens,
+ vision_hidden_size=vision_hidden_size,
+ head_dim=self.config.head_dim,
+ num_heads=self.config.num_heads,
+ sd=self.sd_ref(),
+ config=self.config
+ )
+ elif self.adapter_type == 'text_encoder':
+ if self.config.text_encoder_arch == 't5':
+ te_kwargs = {}
+ # te_kwargs['load_in_4bit'] = True
+ # te_kwargs['load_in_8bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+
+ self.te = T5EncoderModel.from_pretrained(
+ self.config.text_encoder_path,
+ torch_dtype=torch_dtype,
+ **te_kwargs
+ )
+
+ # self.te.to = lambda *args, **kwargs: None
+ self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
+ elif self.config.text_encoder_arch == 'pile-t5':
+ te_kwargs = {}
+ # te_kwargs['load_in_4bit'] = True
+ # te_kwargs['load_in_8bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+
+ self.te = UMT5EncoderModel.from_pretrained(
+ self.config.text_encoder_path,
+ torch_dtype=torch_dtype,
+ **te_kwargs
+ )
+
+ # self.te.to = lambda *args, **kwargs: None
+ self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+ elif self.config.text_encoder_arch == 'clip':
+ self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
+ dtype=torch_dtype)
+ self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path)
+ else:
+ raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
+
+ self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+ elif self.adapter_type == 'te_augmenter':
+ self.te_augmenter = TEAugAdapter(self, self.sd_ref())
+ elif self.adapter_type == 'vision_direct':
+ self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
+ elif self.adapter_type == 'single_value':
+ self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
+ elif self.adapter_type == 'redux':
+ vision_hidden_size = self.vision_encoder.config.hidden_size
+ self.redux_adapter = ReduxImageEncoder(vision_hidden_size, 4096, self.device, torch_dtype)
+ else:
+ raise ValueError(f"unknown adapter type: {self.adapter_type}")
+
+ def forward(self, *args, **kwargs):
+ # dont think this is used
+ # if self.adapter_type == 'photo_maker':
+ # id_pixel_values = args[0]
+ # prompt_embeds: PromptEmbeds = args[1]
+ # class_tokens_mask = args[2]
+ #
+ # grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
+ #
+ # with torch.set_grad_enabled(grads_on_image_encoder):
+ # id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
+ #
+ # if not grads_on_image_encoder:
+ # id_embeds = id_embeds.detach()
+ #
+ # prompt_embeds = prompt_embeds.detach()
+ #
+ # updated_prompt_embeds = self.fuse_module(
+ # prompt_embeds, id_embeds, class_tokens_mask
+ # )
+ #
+ # return updated_prompt_embeds
+ # else:
+ raise NotImplementedError
+
+ def setup_clip(self):
+ adapter_config = self.config
+ sd = self.sd_ref()
+ if self.config.type == "text_encoder" or self.config.type == "single_value":
+ return
+ if self.config.type == 'photo_maker':
+ try:
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
+ except EnvironmentError:
+ self.image_processor = CLIPImageProcessor()
+ if self.config.image_encoder_path is None:
+ self.vision_encoder = PhotoMakerCLIPEncoder()
+ else:
+ self.vision_encoder = PhotoMakerCLIPEncoder.from_pretrained(self.config.image_encoder_path)
+ elif self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
+ try:
+ self.image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.image_processor = CLIPImageProcessor()
+ self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ adapter_config.image_encoder_path,
+ ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'siglip':
+ from transformers import SiglipImageProcessor, SiglipVisionModel
+ try:
+ self.image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.image_processor = SiglipImageProcessor()
+ self.vision_encoder = SiglipVisionModel.from_pretrained(
+ adapter_config.image_encoder_path,
+ ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'pixtral':
+ self.image_processor = PixtralVisionImagePreprocessorCompatible(
+ max_image_size=self.config.pixtral_max_image_size,
+ )
+ self.vision_encoder = PixtralVisionEncoderCompatible.from_pretrained(
+ adapter_config.image_encoder_path,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'vit':
+ try:
+ self.image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.image_processor = ViTFeatureExtractor()
+ self.vision_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'safe':
+ try:
+ self.image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.image_processor = SAFEImageProcessor()
+ self.vision_encoder = SAFEVisionModel(
+ in_channels=3,
+ num_tokens=self.config.safe_tokens,
+ num_vectors=sd.unet.config['cross_attention_dim'],
+ reducer_channels=self.config.safe_reducer_channels,
+ channels=self.config.safe_channels,
+ downscale_factor=8
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'convnext':
+ try:
+ self.image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ print(f"could not load image processor from {adapter_config.image_encoder_path}")
+ self.image_processor = ConvNextImageProcessor(
+ size=320,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ )
+ self.vision_encoder = ConvNextForImageClassification.from_pretrained(
+ adapter_config.image_encoder_path,
+ use_safetensors=True,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'vit-hybrid':
+ try:
+ self.image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ print(f"could not load image processor from {adapter_config.image_encoder_path}")
+ self.image_processor = ViTHybridImageProcessor(
+ size=320,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ )
+ self.vision_encoder = ViTHybridForImageClassification.from_pretrained(
+ adapter_config.image_encoder_path,
+ use_safetensors=True,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ else:
+ raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
+
+ self.input_size = self.vision_encoder.config.image_size
+
+ if self.config.quad_image: # 4x4 image
+ # self.clip_image_processor.config
+ # We do a 3x downscale of the image, so we need to adjust the input size
+ preprocessor_input_size = self.vision_encoder.config.image_size * 2
+
+ # update the preprocessor so images come in at the right size
+ if 'height' in self.image_processor.size:
+ self.image_processor.size['height'] = preprocessor_input_size
+ self.image_processor.size['width'] = preprocessor_input_size
+ elif hasattr(self.image_processor, 'crop_size'):
+ self.image_processor.size['shortest_edge'] = preprocessor_input_size
+ self.image_processor.crop_size['height'] = preprocessor_input_size
+ self.image_processor.crop_size['width'] = preprocessor_input_size
+
+ if self.config.image_encoder_arch == 'clip+':
+ # self.image_processor.config
+ # We do a 3x downscale of the image, so we need to adjust the input size
+ preprocessor_input_size = self.vision_encoder.config.image_size * 4
+
+ # update the preprocessor so images come in at the right size
+ self.image_processor.size['shortest_edge'] = preprocessor_input_size
+ self.image_processor.crop_size['height'] = preprocessor_input_size
+ self.image_processor.crop_size['width'] = preprocessor_input_size
+
+ self.preprocessor = CLIPImagePreProcessor(
+ input_size=preprocessor_input_size,
+ clip_input_size=self.vision_encoder.config.image_size,
+ )
+ if 'height' in self.image_processor.size:
+ self.input_size = self.image_processor.size['height']
+ else:
+ self.input_size = self.image_processor.crop_size['height']
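+ # Sizing sketch for the branches above, assuming a 224px CLIP vision encoder:
+ # quad_image doubles the requested preprocessor size (224 -> 448, presumably for a
+ # 2x2 grid of tiles), while 'clip+' requests 4x (896) and feeds the result through
+ # CLIPImagePreProcessor(input_size=896, clip_input_size=224).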
+
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+ strict = False
+ if self.config.train_only_image_encoder and 'vd_adapter' not in state_dict and 'dvadapter' not in state_dict:
+ # we are loading pure clip weights.
+ self.vision_encoder.load_state_dict(state_dict, strict=strict)
+
+ if 'lora_weights' in state_dict:
+ # todo add LoRA
+ # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
+ # self.sd_ref().pipeline.fuse_lora()
+ pass
+ if 'clip_fusion' in state_dict:
+ self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
+ if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
+ self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
+ # check to see if the fuse weights are there
+ fuse_weights = {}
+ for k, v in state_dict['id_encoder'].items():
+ if k.startswith('fuse_module'):
+ k = k.replace('fuse_module.', '')
+ fuse_weights[k] = v
+ if len(fuse_weights) > 0:
+ try:
+ self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+ except Exception as e:
+
+ print(e)
+ # force load it
+ print(f"force loading fuse module as it did not match")
+ current_state_dict = self.fuse_module.state_dict()
+ for k, v in fuse_weights.items():
+ if len(v.shape) == 1:
+ current_state_dict[k] = v[:current_state_dict[k].shape[0]]
+ elif len(v.shape) == 2:
+ current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
+ elif len(v.shape) == 3:
+ current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+ :current_state_dict[k].shape[2]]
+ elif len(v.shape) == 4:
+ current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+ :current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
+ else:
+ raise ValueError(f"unknown shape: {v.shape}")
+ self.fuse_module.load_state_dict(current_state_dict, strict=strict)
+
+ if 'te_adapter' in state_dict:
+ self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
+
+ if 'te_augmenter' in state_dict:
+ self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
+
+ if 'vd_adapter' in state_dict:
+ self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
+ if 'dvadapter' in state_dict:
+ self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False)
+
+ if 'sv_adapter' in state_dict:
+ self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
+
+ if 'vision_encoder' in state_dict:
+ self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
+
+ if 'fuse_module' in state_dict:
+ self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
+
+ if 'ilora' in state_dict:
+ try:
+ self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+ except Exception as e:
+ print(e)
+ if 'redux_up' in state_dict:
+ # state dict is separated, so recombine it
+ new_dict = {}
+ for k, v in state_dict.items():
+ for k2, v2 in v.items():
+ new_dict[k + '.' + k2] = v2
+ self.redux_adapter.load_state_dict(new_dict, strict=True)
+
+ pass
+
+ def state_dict(self) -> OrderedDict:
+ state_dict = OrderedDict()
+ if self.config.train_only_image_encoder:
+ return self.vision_encoder.state_dict()
+
+ if self.adapter_type == 'photo_maker':
+ if self.config.train_image_encoder:
+ state_dict["id_encoder"] = self.vision_encoder.state_dict()
+
+ state_dict["fuse_module"] = self.fuse_module.state_dict()
+
+ # todo save LoRA
+ return state_dict
+
+ elif self.adapter_type == 'clip_fusion':
+ if self.config.train_image_encoder:
+ state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+ state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
+ return state_dict
+ elif self.adapter_type == 'text_encoder':
+ state_dict["te_adapter"] = self.te_adapter.state_dict()
+ return state_dict
+ elif self.adapter_type == 'te_augmenter':
+ if self.config.train_image_encoder:
+ state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+ state_dict["te_augmenter"] = self.te_augmenter.state_dict()
+ return state_dict
+ elif self.adapter_type == 'vision_direct':
+ state_dict["dvadapter"] = self.vd_adapter.state_dict()
+ # if self.config.train_image_encoder: # always return vision encoder
+ state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+ return state_dict
+ elif self.adapter_type == 'single_value':
+ state_dict["sv_adapter"] = self.single_value_adapter.state_dict()
+ return state_dict
+ elif self.adapter_type == 'ilora':
+ if self.config.train_image_encoder:
+ state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+ state_dict["ilora"] = self.ilora_module.state_dict()
+ return state_dict
+ elif self.adapter_type == 'redux':
+ d = self.redux_adapter.state_dict()
+ for k, v in d.items():
+ state_dict[k] = v
+ return state_dict
+ else:
+ raise NotImplementedError
+
+ def add_extra_values(self, extra_values: torch.Tensor, is_unconditional=False):
+ if self.adapter_type == 'single_value':
+ if is_unconditional:
+ self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+ else:
+ self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+
+
+ def condition_prompt(
+ self,
+ prompt: Union[List[str], str],
+ is_unconditional: bool = False,
+ ):
+ if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux':
+ return prompt
+ elif self.adapter_type == 'text_encoder':
+ # todo allow for training
+ with torch.no_grad():
+ # encode and save the embeds
+ if is_unconditional:
+ self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach()
+ else:
+ self.conditional_embeds = self.te_adapter.encode_text(prompt).detach()
+ return prompt
+ elif self.adapter_type == 'photo_maker':
+ if is_unconditional:
+ return prompt
+ else:
+
+ with torch.no_grad():
+ was_list = isinstance(prompt, list)
+ if not was_list:
+ prompt_list = [prompt]
+ else:
+ prompt_list = prompt
+
+ new_prompt_list = []
+ token_mask_list = []
+
+ for prompt in prompt_list:
+
+ our_class = None
+ # find a class in the prompt
+ prompt_parts = prompt.split(' ')
+ prompt_parts = [p.strip().lower() for p in prompt_parts if len(p) > 0]
+
+ new_prompt_parts = []
+ tokened_prompt_parts = []
+ for idx, prompt_part in enumerate(prompt_parts):
+ new_prompt_parts.append(prompt_part)
+ tokened_prompt_parts.append(prompt_part)
+ if prompt_part in self.config.class_names:
+ our_class = prompt_part
+ # add the flag word
+ tokened_prompt_parts.append(self.flag_word)
+
+ if self.num_control_images > 1:
+ # add the rest
+ for _ in range(self.num_control_images - 1):
+ new_prompt_parts.extend(prompt_parts[idx + 1:])
+
+ # add the rest
+ tokened_prompt_parts.extend(prompt_parts[idx + 1:])
+ new_prompt_parts.extend(prompt_parts[idx + 1:])
+
+ break
+
+ prompt = " ".join(new_prompt_parts)
+ tokened_prompt = " ".join(tokened_prompt_parts)
+
+ if our_class is None:
+ # add the first one to the front of the prompt
+ tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
+ our_class = self.config.class_names[0]
+ prompt = " ".join(
+ [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
+
+ # add the prompt to the list
+ new_prompt_list.append(prompt)
+
+ # tokenize them with just the first tokenizer
+ tokenizer = self.sd_ref().tokenizer
+ if isinstance(tokenizer, list):
+ tokenizer = tokenizer[0]
+
+ flag_token = tokenizer.convert_tokens_to_ids(self.flag_word)
+
+ tokenized_prompt = tokenizer.encode(prompt)
+ tokenized_tokened_prompt = tokenizer.encode(tokened_prompt)
+
+ flag_idx = tokenized_tokened_prompt.index(flag_token)
+
+ class_token = tokenized_prompt[flag_idx - 1]
+
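+ # mark the token positions that the fuse module will combine with the image id
+ # embeds; the mask is passed to fuse_module as token_mask below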
+ boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
+ boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
+ boolean_mask = boolean_mask.to(self.device)
+ # zero pad it to 77
+ boolean_mask = F.pad(boolean_mask, (0, 77 - boolean_mask.shape[0]), value=False)
+
+ token_mask_list.append(boolean_mask)
+
+ self.token_mask = torch.cat(token_mask_list, dim=0).to(self.device)
+
+ prompt_list = new_prompt_list
+
+ if not was_list:
+ prompt = prompt_list[0]
+ else:
+ prompt = prompt_list
+
+ return prompt
+
+ else:
+ return prompt
+
+ def condition_encoded_embeds(
+ self,
+ tensors_0_1: torch.Tensor,
+ prompt_embeds: PromptEmbeds,
+ is_training=False,
+ has_been_preprocessed=False,
+ is_unconditional=False,
+ quad_count=4,
+ is_generating_samples=False,
+ ) -> PromptEmbeds:
+ if self.adapter_type == 'text_encoder' and is_generating_samples:
+ # replace the prompt embed with ours
+ if is_unconditional:
+ return self.unconditional_embeds.clone()
+ return self.conditional_embeds.clone()
+
+ if self.adapter_type == 'ilora':
+ return prompt_embeds
+
+ if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'redux':
+ if is_unconditional:
+ # we don't condition the negative embeds for photo maker
+ return prompt_embeds.clone()
+ with torch.no_grad():
+ # on training the clip image is created in the dataloader
+ if not has_been_preprocessed:
+ # tensors should be 0-1
+ if tensors_0_1.ndim == 3:
+ tensors_0_1 = tensors_0_1.unsqueeze(0)
+ # training tensors are 0 - 1
+ tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+ # if images are out of this range throw error
+ if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+ raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+ tensors_0_1.min(), tensors_0_1.max()
+ ))
+ clip_image = self.image_processor(
+ images=tensors_0_1,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ do_convert_rgb=True
+ ).pixel_values
+ else:
+ clip_image = tensors_0_1
+ clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+
+ if self.config.quad_image:
+ # split the 2x2 image grid into quadrants and stack on batch
+ ci1, ci2 = clip_image.chunk(2, dim=2)
+ ci1, ci3 = ci1.chunk(2, dim=3)
+ ci2, ci4 = ci2.chunk(2, dim=3)
+ to_cat = []
+ for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+ if i < quad_count:
+ to_cat.append(ci)
+ else:
+ break
+
+ clip_image = torch.cat(to_cat, dim=0).detach()
+
+ if self.adapter_type == 'photo_maker':
+ # Embeddings need to be (b, num_inputs, c, h, w); for now, just use 1 input image
+ clip_image = clip_image.unsqueeze(1)
+ with torch.set_grad_enabled(is_training):
+ if is_training and self.config.train_image_encoder:
+ self.vision_encoder.train()
+ clip_image = clip_image.requires_grad_(True)
+ id_embeds = self.vision_encoder(
+ clip_image,
+ do_projection2=isinstance(self.sd_ref().text_encoder, list),
+ )
+ else:
+ with torch.no_grad():
+ self.vision_encoder.eval()
+ id_embeds = self.vision_encoder(
+ clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
+ ).detach()
+
+ prompt_embeds.text_embeds = self.fuse_module(
+ prompt_embeds.text_embeds,
+ id_embeds,
+ self.token_mask
+ )
+ return prompt_embeds
+ elif self.adapter_type == 'clip_fusion':
+ with torch.set_grad_enabled(is_training):
+ if is_training and self.config.train_image_encoder:
+ self.vision_encoder.train()
+ clip_image = clip_image.requires_grad_(True)
+ id_embeds = self.vision_encoder(
+ clip_image,
+ output_hidden_states=True,
+ )
+ else:
+ with torch.no_grad():
+ self.vision_encoder.eval()
+ id_embeds = self.vision_encoder(
+ clip_image, output_hidden_states=True
+ )
+
+ img_embeds = id_embeds['last_hidden_state']
+
+ if self.config.quad_image:
+ # get the outputs of the quad
+ chunks = img_embeds.chunk(quad_count, dim=0)
+ chunk_sum = torch.zeros_like(chunks[0])
+ for chunk in chunks:
+ chunk_sum = chunk_sum + chunk
+ # get the mean of them
+
+ img_embeds = chunk_sum / quad_count
+
+ if not is_training or not self.config.train_image_encoder:
+ img_embeds = img_embeds.detach()
+
+ prompt_embeds.text_embeds = self.clip_fusion_module(
+ prompt_embeds.text_embeds,
+ img_embeds
+ )
+ return prompt_embeds
+
+ elif self.adapter_type == 'redux':
+ with torch.set_grad_enabled(is_training):
+ if is_training and self.config.train_image_encoder:
+ self.vision_encoder.train()
+ clip_image = clip_image.requires_grad_(True)
+ id_embeds = self.vision_encoder(
+ clip_image,
+ output_hidden_states=True,
+ )
+ else:
+ with torch.no_grad():
+ self.vision_encoder.eval()
+ id_embeds = self.vision_encoder(
+ clip_image, output_hidden_states=True
+ )
+
+ img_embeds = id_embeds['last_hidden_state']
+
+ if self.config.quad_image:
+ # get the outputs of the quad
+ chunks = img_embeds.chunk(quad_count, dim=0)
+ chunk_sum = torch.zeros_like(chunks[0])
+ for chunk in chunks:
+ chunk_sum = chunk_sum + chunk
+ # get the mean of them
+
+ img_embeds = chunk_sum / quad_count
+
+ if not is_training or not self.config.train_image_encoder:
+ img_embeds = img_embeds.detach()
+
+ img_embeds = self.redux_adapter(img_embeds.to(self.device, get_torch_dtype(self.sd_ref().dtype)))
+
+ prompt_embeds.text_embeds = torch.cat((prompt_embeds.text_embeds, img_embeds), dim=-2)
+ return prompt_embeds
+ else:
+ return prompt_embeds
+
+ def get_empty_clip_image(self, batch_size: int, shape=None) -> torch.Tensor:
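+ # build a random-amplitude noise image batch, quantize it to 8-bit levels, and
+ # normalize it with the image processor's mean/std so it can stand in as an
+ # "empty" / unconditional image input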
+ with torch.no_grad():
+ if shape is None:
+ shape = [batch_size, 3, self.input_size, self.input_size]
+ tensors_0_1 = torch.rand(shape, device=self.device)
+ noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+ dtype=get_torch_dtype(self.sd_ref().dtype))
+ tensors_0_1 = tensors_0_1 * noise_scale
+ # tensors_0_1 = tensors_0_1 * 0
+ mean = torch.tensor(self.clip_image_processor.image_mean).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ std = torch.tensor(self.clip_image_processor.image_std).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+ clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
+ return clip_image.detach()
+
+ def train(self, mode: bool = True):
+ if self.config.train_image_encoder:
+ self.vision_encoder.train(mode)
+ super().train(mode)
+
+ def trigger_pre_te(
+ self,
+ tensors_0_1: torch.Tensor,
+ is_training=False,
+ has_been_preprocessed=False,
+ quad_count=4,
+ batch_size=1,
+ ) -> PromptEmbeds:
+ if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+ skip_unconditional = self.sd_ref().is_flux
+ if tensors_0_1 is None:
+ tensors_0_1 = self.get_empty_clip_image(batch_size)
+ has_been_preprocessed = True
+
+ with torch.no_grad():
+ # on training the clip image is created in the dataloader
+ if not has_been_preprocessed:
+ # tensors should be 0-1
+ if tensors_0_1.ndim == 3:
+ tensors_0_1 = tensors_0_1.unsqueeze(0)
+ # training tensors are 0 - 1
+ tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+ # if images are out of this range throw error
+ if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+ raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+ tensors_0_1.min(), tensors_0_1.max()
+ ))
+ clip_image = self.image_processor(
+ images=tensors_0_1,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ ).pixel_values
+ else:
+ clip_image = tensors_0_1
+
+ # if is pixtral
+ if self.config.image_encoder_arch == 'pixtral' and self.config.pixtral_random_image_size:
+ # get the random size
+ random_size = random.randint(256, self.config.pixtral_max_image_size)
+ # images are already sized for the max size; fit them to the Pixtral patch size to reduce / enlarge them further.
+ h, w = clip_image.shape[2], clip_image.shape[3]
+ current_base_size = int(math.sqrt(w * h))
+ ratio = current_base_size / random_size
+ if ratio > 1:
+ w = round(w / ratio)
+ h = round(h / ratio)
+
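+ # round the token grid up so the whole image is covered: (d - 1) // patch + 1 == ceil(d / patch)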
+ width_tokens = (w - 1) // self.image_processor.image_patch_size + 1
+ height_tokens = (h - 1) // self.image_processor.image_patch_size + 1
+ assert width_tokens > 0
+ assert height_tokens > 0
+
+ new_image_size = (
+ width_tokens * self.image_processor.image_patch_size,
+ height_tokens * self.image_processor.image_patch_size,
+ )
+
+ # resize the image
+ clip_image = F.interpolate(clip_image, size=new_image_size, mode='bicubic', align_corners=False)
+
+
+ batch_size = clip_image.shape[0]
+ if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional:
+ # add an unconditional so we can save it
+ unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
+ clip_image.device, dtype=clip_image.dtype
+ )
+ clip_image = torch.cat([unconditional, clip_image], dim=0)
+
+ clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+
+ if self.config.quad_image:
+ # split the 2x2 image grid into quadrants and stack on batch
+ ci1, ci2 = clip_image.chunk(2, dim=2)
+ ci1, ci3 = ci1.chunk(2, dim=3)
+ ci2, ci4 = ci2.chunk(2, dim=3)
+ to_cat = []
+ for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+ if i < quad_count:
+ to_cat.append(ci)
+ else:
+ break
+
+ clip_image = torch.cat(to_cat, dim=0).detach()
+
+ if self.adapter_type == 'ilora':
+ with torch.set_grad_enabled(is_training):
+ if is_training and self.config.train_image_encoder:
+ self.vision_encoder.train()
+ clip_image = clip_image.requires_grad_(True)
+ id_embeds = self.vision_encoder(
+ clip_image,
+ output_hidden_states=True,
+ )
+ else:
+ with torch.no_grad():
+ self.vision_encoder.eval()
+ id_embeds = self.vision_encoder(
+ clip_image, output_hidden_states=True
+ )
+
+ if self.config.clip_layer == 'penultimate_hidden_states':
+ img_embeds = id_embeds.hidden_states[-2]
+ elif self.config.clip_layer == 'last_hidden_state':
+ img_embeds = id_embeds.hidden_states[-1]
+ elif self.config.clip_layer == 'image_embeds':
+ img_embeds = id_embeds.image_embeds
+ else:
+ raise ValueError(f"unknown clip layer: {self.config.clip_layer}")
+
+ if self.config.quad_image:
+ # get the outputs of the quad
+ chunks = img_embeds.chunk(quad_count, dim=0)
+ chunk_sum = torch.zeros_like(chunks[0])
+ for chunk in chunks:
+ chunk_sum = chunk_sum + chunk
+ # get the mean of them
+
+ img_embeds = chunk_sum / quad_count
+
+ if not is_training or not self.config.train_image_encoder:
+ img_embeds = img_embeds.detach()
+
+ self.ilora_module(img_embeds)
+ if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+ with torch.set_grad_enabled(is_training):
+ if is_training and self.config.train_image_encoder:
+ self.vision_encoder.train()
+ clip_image = clip_image.requires_grad_(True)
+ else:
+ with torch.no_grad():
+ self.vision_encoder.eval()
+ clip_output = self.vision_encoder(
+ clip_image,
+ output_hidden_states=True,
+ )
+ if self.config.clip_layer == 'penultimate_hidden_states':
+ # they skip last layer for ip+
+ # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
+ clip_image_embeds = clip_output.hidden_states[-2]
+ elif self.config.clip_layer == 'last_hidden_state':
+ clip_image_embeds = clip_output.hidden_states[-1]
+ else:
+ if hasattr(clip_output, 'image_embeds'):
+ clip_image_embeds = clip_output.image_embeds
+ elif hasattr(clip_output, 'pooler_output'):
+ clip_image_embeds = clip_output.pooler_output
+ # TODO should we always norm image embeds?
+ # get norm embeddings
+ # l2_norm = torch.norm(clip_image_embeds, p=2)
+ # clip_image_embeds = clip_image_embeds / l2_norm
+
+ if not is_training or not self.config.train_image_encoder:
+ clip_image_embeds = clip_image_embeds.detach()
+
+ if self.adapter_type == 'te_augmenter':
+ clip_image_embeds = self.te_augmenter(clip_image_embeds)
+
+ if self.adapter_type == 'vision_direct':
+ clip_image_embeds = self.vd_adapter(clip_image_embeds)
+
+ # save them to the conditional and unconditional
+ try:
+ if skip_unconditional:
+ self.unconditional_embeds, self.conditional_embeds = None, clip_image_embeds
+ else:
+ self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+ except ValueError:
+ raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")
+
+ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+ if self.config.train_only_image_encoder:
+ yield from self.vision_encoder.parameters(recurse)
+ return
+ if self.config.type == 'photo_maker':
+ yield from self.fuse_module.parameters(recurse)
+ if self.config.train_image_encoder:
+ yield from self.vision_encoder.parameters(recurse)
+ elif self.config.type == 'clip_fusion':
+ yield from self.clip_fusion_module.parameters(recurse)
+ if self.config.train_image_encoder:
+ yield from self.vision_encoder.parameters(recurse)
+ elif self.config.type == 'ilora':
+ yield from self.ilora_module.parameters(recurse)
+ if self.config.train_image_encoder:
+ yield from self.vision_encoder.parameters(recurse)
+ elif self.config.type == 'text_encoder':
+ for attn_processor in self.te_adapter.adapter_modules:
+ yield from attn_processor.parameters(recurse)
+ elif self.config.type == 'vision_direct':
+ if self.config.train_scaler:
+ # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
+ yield self.vd_adapter.block_scaler
+ else:
+ for attn_processor in self.vd_adapter.adapter_modules:
+ yield from attn_processor.parameters(recurse)
+ if self.config.train_image_encoder:
+ yield from self.vision_encoder.parameters(recurse)
+ if self.vd_adapter.resampler is not None:
+ yield from self.vd_adapter.resampler.parameters(recurse)
+ if self.vd_adapter.pool is not None:
+ yield from self.vd_adapter.pool.parameters(recurse)
+ if self.vd_adapter.sparse_autoencoder is not None:
+ yield from self.vd_adapter.sparse_autoencoder.parameters(recurse)
+ elif self.config.type == 'te_augmenter':
+ yield from self.te_augmenter.parameters(recurse)
+ if self.config.train_image_encoder:
+ yield from self.vision_encoder.parameters(recurse)
+ elif self.config.type == 'single_value':
+ yield from self.single_value_adapter.parameters(recurse)
+ elif self.config.type == 'redux':
+ yield from self.redux_adapter.parameters(recurse)
+ else:
+ raise NotImplementedError
+
+ def enable_gradient_checkpointing(self):
+ if hasattr(self.vision_encoder, "enable_gradient_checkpointing"):
+ self.vision_encoder.enable_gradient_checkpointing()
+ elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
+ self.vision_encoder.gradient_checkpointing = True
+
+ def get_additional_save_metadata(self) -> Dict[str, Any]:
+ additional = {}
+ if self.config.type == 'ilora':
+ extra = self.ilora_module.get_additional_save_metadata()
+ for k, v in extra.items():
+ additional[k] = v
+ additional['clip_layer'] = self.config.clip_layer
+ additional['image_encoder_arch'] = self.config.image_encoder_arch
+ return additional
+
+ def post_weight_update(self):
+ # do any kind of updates after the weight update
+ if self.config.type == 'vision_direct':
+ self.vd_adapter.post_weight_update()
+ pass
\ No newline at end of file
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5285b371734e874c831116fb6799fa99f1df5d48
--- /dev/null
+++ b/toolkit/data_loader.py
@@ -0,0 +1,677 @@
+import copy
+import json
+import os
+import random
+import traceback
+from functools import lru_cache
+from typing import List, TYPE_CHECKING
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torchvision import transforms
+from torch.utils.data import Dataset, DataLoader, ConcatDataset
+from tqdm import tqdm
+import albumentations as A
+
+from toolkit.buckets import get_bucket_for_image_size, BucketResolution
+from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
+from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments, CLIPCachingMixin
+from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+
+import platform
+
+def is_native_windows():
+ return platform.system() == "Windows" and platform.release() != "2"
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+class RescaleTransform:
+ """Transform to rescale images to the range [-1, 1]."""
+
+ def __call__(self, image):
+ return image * 2 - 1
+
+
+class NormalizeSDXLTransform:
+ """
+ Transforms images in the 0-1 range to SDXL per-channel mean and std, based on averages over thousands of images
+
+ Mean: tensor([ 0.0002, -0.1034, -0.1879])
+ Standard Deviation: tensor([0.5436, 0.5116, 0.5033])
+ """
+
+ def __call__(self, image):
+ return transforms.Normalize(
+ mean=[0.0002, -0.1034, -0.1879],
+ std=[0.5436, 0.5116, 0.5033],
+ )(image)
+
+
+class NormalizeSD15Transform:
+ """
+ Transforms images in the 0-1 range to SD 1.5 per-channel mean and std, based on averages over thousands of images
+
+ Mean: tensor([-0.1600, -0.2450, -0.3227])
+ Standard Deviation: tensor([0.5319, 0.4997, 0.5139])
+
+ """
+
+ def __call__(self, image):
+ return transforms.Normalize(
+ mean=[-0.1600, -0.2450, -0.3227],
+ std=[0.5319, 0.4997, 0.5139],
+ )(image)
+
+
+
+class ImageDataset(Dataset, CaptionMixin):
+ def __init__(self, config):
+ self.config = config
+ self.name = self.get_config('name', 'dataset')
+ self.path = self.get_config('path', required=True)
+ self.scale = self.get_config('scale', 1)
+ self.random_scale = self.get_config('random_scale', False)
+ self.include_prompt = self.get_config('include_prompt', False)
+ self.default_prompt = self.get_config('default_prompt', '')
+ if self.include_prompt:
+ self.caption_type = self.get_config('caption_ext', 'txt')
+ else:
+ self.caption_type = None
+ # we always random crop if random scale is enabled
+ self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)
+
+ self.resolution = self.get_config('resolution', 256)
+ self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
+ file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+
+ # this might take a while
+ print(f" - Preprocessing image dimensions")
+ new_file_list = []
+ bad_count = 0
+ for file in tqdm(self.file_list):
+ img = Image.open(file)
+ if int(min(img.size) * self.scale) >= self.resolution:
+ new_file_list.append(file)
+ else:
+ bad_count += 1
+
+ self.file_list = new_file_list
+
+ print(f" - Found {len(self.file_list)} images")
+ print(f" - Found {bad_count} images that are too small")
+ assert len(self.file_list) > 0, f"no images found in {self.path}"
+
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ RescaleTransform(),
+ ])
+
+ def get_config(self, key, default=None, required=False):
+ if key in self.config:
+ value = self.config[key]
+ return value
+ elif required:
+ raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
+ else:
+ return default
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def __getitem__(self, index):
+ img_path = self.file_list[index]
+ try:
+ img = exif_transpose(Image.open(img_path)).convert('RGB')
+ except Exception as e:
+ print(f"Error opening image: {img_path}")
+ print(e)
+ # make a noise image if we can't open it
+ img = Image.fromarray(np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8))
+
+ # Downscale the source image first
+ img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
+ min_img_size = min(img.size)
+
+ if self.random_crop:
+ if self.random_scale and min_img_size > self.resolution:
+ if min_img_size < self.resolution:
+ print(
+ f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
+ scale_size = self.resolution
+ else:
+ scale_size = random.randint(self.resolution, int(min_img_size))
+ scaler = scale_size / min_img_size
+ scale_width = int((img.width + 5) * scaler)
+ scale_height = int((img.height + 5) * scaler)
+ img = img.resize((scale_width, scale_height), Image.BICUBIC)
+ img = transforms.RandomCrop(self.resolution)(img)
+ else:
+ img = transforms.CenterCrop(min_img_size)(img)
+ img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
+
+ img = self.transform(img)
+
+ if self.include_prompt:
+ prompt = self.get_caption_item(index)
+ return img, prompt
+ else:
+ return img
+
+
+
+
+
+class AugmentedImageDataset(ImageDataset):
+ def __init__(self, config):
+ super().__init__(config)
+ self.augmentations = self.get_config('augmentations', [])
+ self.augmentations = [Augments(**aug) for aug in self.augmentations]
+
+ augmentation_list = []
+ for aug in self.augmentations:
+ # make sure method name is valid
+ assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
+ # get the method
+ method = getattr(A, aug.method_name)
+ # add the method to the list
+ augmentation_list.append(method(**aug.params))
+
+ self.aug_transform = A.Compose(augmentation_list)
+ self.original_transform = self.transform
+ # replace transform so we get raw pil image
+ self.transform = transforms.Compose([])
+
+ def __getitem__(self, index):
+ # get the original image
+ # image is a PIL image, convert to bgr
+ pil_image = super().__getitem__(index)
+ open_cv_image = np.array(pil_image)
+ # Convert RGB to BGR
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
+
+ # apply augmentations
+ augmented = self.aug_transform(image=open_cv_image)["image"]
+
+ # convert back to RGB tensor
+ augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
+
+ # convert to PIL image
+ augmented = Image.fromarray(augmented)
+
+ # return both the original and augmented images as 0-1 tensors
+ return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)
+
+
+class PairedImageDataset(Dataset):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.size = self.get_config('size', 512)
+ self.path = self.get_config('path', None)
+ self.pos_folder = self.get_config('pos_folder', None)
+ self.neg_folder = self.get_config('neg_folder', None)
+
+ self.default_prompt = self.get_config('default_prompt', '')
+ self.network_weight = self.get_config('network_weight', 1.0)
+ self.pos_weight = self.get_config('pos_weight', self.network_weight)
+ self.neg_weight = self.get_config('neg_weight', self.network_weight)
+
+ supported_exts = ('.jpg', '.jpeg', '.png', '.webp', '.JPEG', '.JPG', '.PNG', '.WEBP')
+
+ if self.pos_folder is not None and self.neg_folder is not None:
+ # find matching files
+ self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
+ file.lower().endswith(supported_exts)]
+ self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
+ file.lower().endswith(supported_exts)]
+
+ matched_files = []
+ for pos_file in self.pos_file_list:
+ pos_file_no_ext = os.path.splitext(pos_file)[0]
+ for neg_file in self.neg_file_list:
+ neg_file_no_ext = os.path.splitext(neg_file)[0]
+ if os.path.basename(pos_file_no_ext) == os.path.basename(neg_file_no_ext):
+ matched_files.append((neg_file, pos_file))
+ break
+
+ # remove duplicates
+ matched_files = [t for t in (set(tuple(i) for i in matched_files))]
+
+ self.file_list = matched_files
+ print(f" - Found {len(self.file_list)} matching pairs")
+ else:
+ self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
+ file.lower().endswith(supported_exts)]
+ print(f" - Found {len(self.file_list)} images")
+
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ RescaleTransform(),
+ ])
+
+ def get_all_prompts(self):
+ prompts = []
+ for index in range(len(self.file_list)):
+ prompts.append(self.get_prompt_item(index))
+
+ # remove duplicates
+ prompts = list(set(prompts))
+ return prompts
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def get_config(self, key, default=None, required=False):
+ if key in self.config:
+ value = self.config[key]
+ return value
+ elif required:
+ raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
+ else:
+ return default
+
+ def get_prompt_item(self, index):
+ img_path_or_tuple = self.file_list[index]
+ if isinstance(img_path_or_tuple, tuple):
+ # check if either has a prompt file
+ path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
+ prompt_path = path_no_ext + '.txt'
+ if not os.path.exists(prompt_path):
+ path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
+ prompt_path = path_no_ext + '.txt'
+ else:
+ img_path = img_path_or_tuple
+ # see if prompt file exists
+ path_no_ext = os.path.splitext(img_path)[0]
+ prompt_path = path_no_ext + '.txt'
+
+ if os.path.exists(prompt_path):
+ with open(prompt_path, 'r', encoding='utf-8') as f:
+ prompt = f.read()
+ # remove any newlines
+ prompt = prompt.replace('\n', ', ')
+ # remove new lines for all operating systems
+ prompt = prompt.replace('\r', ', ')
+ prompt_split = prompt.split(',')
+ # remove empty strings
+ prompt_split = [p.strip() for p in prompt_split if p.strip()]
+ # join back together
+ prompt = ', '.join(prompt_split)
+ else:
+ prompt = self.default_prompt
+ return prompt
+
+ def __getitem__(self, index):
+ img_path_or_tuple = self.file_list[index]
+ if isinstance(img_path_or_tuple, tuple):
+ # load both images
+ img_path = img_path_or_tuple[0]
+ img1 = exif_transpose(Image.open(img_path)).convert('RGB')
+ img_path = img_path_or_tuple[1]
+ img2 = exif_transpose(Image.open(img_path)).convert('RGB')
+
+ # always use # 2 (pos)
+ bucket_resolution = get_bucket_for_image_size(
+ width=img2.width,
+ height=img2.height,
+ resolution=self.size,
+ # divisibility=self.
+ )
+
+ # images will be the same base dimension, but may be trimmed. We need to shrink and then center crop
+ if bucket_resolution['width'] > bucket_resolution['height']:
+ img1_scale_to_height = bucket_resolution["height"]
+ img1_scale_to_width = int(img1.width * (bucket_resolution["height"] / img1.height))
+ img2_scale_to_height = bucket_resolution["height"]
+ img2_scale_to_width = int(img2.width * (bucket_resolution["height"] / img2.height))
+ else:
+ img1_scale_to_width = bucket_resolution["width"]
+ img1_scale_to_height = int(img1.height * (bucket_resolution["width"] / img1.width))
+ img2_scale_to_width = bucket_resolution["width"]
+ img2_scale_to_height = int(img2.height * (bucket_resolution["width"] / img2.width))
+
+ img1_crop_height = bucket_resolution["height"]
+ img1_crop_width = bucket_resolution["width"]
+ img2_crop_height = bucket_resolution["height"]
+ img2_crop_width = bucket_resolution["width"]
+
+ # scale then center crop images
+ img1 = img1.resize((img1_scale_to_width, img1_scale_to_height), Image.BICUBIC)
+ img1 = transforms.CenterCrop((img1_crop_height, img1_crop_width))(img1)
+ img2 = img2.resize((img2_scale_to_width, img2_scale_to_height), Image.BICUBIC)
+ img2 = transforms.CenterCrop((img2_crop_height, img2_crop_width))(img2)
+
+ # combine them side by side
+ img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
+ img.paste(img1, (0, 0))
+ img.paste(img2, (img1.width, 0))
+ else:
+ img_path = img_path_or_tuple
+ img = exif_transpose(Image.open(img_path)).convert('RGB')
+ height = self.size
+ # determine width to keep aspect ratio
+ width = int(img.size[0] * height / img.size[1])
+
+ # Downscale the source image first
+ img = img.resize((width, height), Image.BICUBIC)
+
+ prompt = self.get_prompt_item(index)
+ img = self.transform(img)
+
+ return img, prompt, (self.neg_weight, self.pos_weight)
+
+
+class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, CaptionMixin, Dataset):
+
+ def __init__(
+ self,
+ dataset_config: 'DatasetConfig',
+ batch_size=1,
+ sd: 'StableDiffusion' = None,
+ ):
+ super().__init__()
+ self.dataset_config = dataset_config
+ folder_path = dataset_config.folder_path
+ self.dataset_path = dataset_config.dataset_path
+ if self.dataset_path is None:
+ self.dataset_path = folder_path
+
+ self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
+ self.is_caching_latents_to_memory = dataset_config.cache_latents
+ self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+ self.is_caching_clip_vision_to_disk = dataset_config.cache_clip_vision_to_disk
+ self.epoch_num = 0
+
+ self.sd = sd
+
+ if self.sd is None and self.is_caching_latents:
+ raise ValueError(f"sd is required for caching latents")
+
+ self.caption_type = dataset_config.caption_ext
+ self.default_caption = dataset_config.default_caption
+ self.random_scale = dataset_config.random_scale
+ self.scale = dataset_config.scale
+ self.batch_size = batch_size
+ # we always random crop if random scale is enabled
+ self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
+ self.resolution = dataset_config.resolution
+ self.caption_dict = None
+ self.file_list: List['FileItemDTO'] = []
+
+ # check if dataset_path is a folder or json
+ if os.path.isdir(self.dataset_path):
+ file_list = [os.path.join(root, file) for root, _, files in os.walk(self.dataset_path) for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
+ else:
+ # assume json
+ with open(self.dataset_path, 'r') as f:
+ self.caption_dict = json.load(f)
+ # keys are file paths
+ file_list = list(self.caption_dict.keys())
+
+ if self.dataset_config.num_repeats > 1:
+ # repeat the list
+ file_list = file_list * self.dataset_config.num_repeats
+
+ if self.dataset_config.standardize_images:
+ if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
+ NormalizeMethod = NormalizeSDXLTransform
+ else:
+ NormalizeMethod = NormalizeSD15Transform
+
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ RescaleTransform(),
+ NormalizeMethod(),
+ ])
+ else:
+ self.transform = transforms.Compose([
+ transforms.ToTensor(),
+ RescaleTransform(),
+ ])
+
+ # this might take a while
+ print(f"Dataset: {self.dataset_path}")
+ print(f" - Preprocessing image dimensions")
+ dataset_folder = self.dataset_path
+ if not os.path.isdir(self.dataset_path):
+ dataset_folder = os.path.dirname(dataset_folder)
+ dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
+ dataloader_version = "0.1.1"
+ if os.path.exists(dataset_size_file):
+ try:
+ with open(dataset_size_file, 'r') as f:
+ self.size_database = json.load(f)
+
+ if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version:
+ print("Upgrading size database to new version")
+ # old version, delete and recreate
+ self.size_database = {}
+ except Exception as e:
+ print(f"Error loading size database: {dataset_size_file}")
+ print(e)
+ self.size_database = {}
+ else:
+ self.size_database = {}
+
+ self.size_database["__version__"] = dataloader_version
+
+ bad_count = 0
+ for file in tqdm(file_list):
+ try:
+ file_item = FileItemDTO(
+ sd=self.sd,
+ path=file,
+ dataset_config=dataset_config,
+ dataloader_transforms=self.transform,
+ size_database=self.size_database,
+ dataset_root=dataset_folder,
+ )
+ self.file_list.append(file_item)
+ except Exception as e:
+ print(traceback.format_exc())
+ print(f"Error processing image: {file}")
+ print(e)
+ bad_count += 1
+
+ # save the size database
+ with open(dataset_size_file, 'w') as f:
+ json.dump(self.size_database, f)
+
+ print(f" - Found {len(self.file_list)} images")
+ # print(f" - Found {bad_count} images that are too small")
+ assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
+
+ # handle x axis flips
+ if self.dataset_config.flip_x:
+ print(" - adding x axis flips")
+ current_file_list = [x for x in self.file_list]
+ for file_item in current_file_list:
+ # create a copy that is flipped on the x axis
+ new_file_item = copy.deepcopy(file_item)
+ new_file_item.flip_x = True
+ self.file_list.append(new_file_item)
+
+ # handle y axis flips
+ if self.dataset_config.flip_y:
+ print(" - adding y axis flips")
+ current_file_list = [x for x in self.file_list]
+ for file_item in current_file_list:
+ # create a copy that is flipped on the y axis
+ new_file_item = copy.deepcopy(file_item)
+ new_file_item.flip_y = True
+ self.file_list.append(new_file_item)
+
+ if self.dataset_config.flip_x or self.dataset_config.flip_y:
+ print(f" - Found {len(self.file_list)} images after adding flips")
+
+
+ self.setup_epoch()
+
+ def setup_epoch(self):
+ if self.epoch_num == 0:
+ # initial setup
+ # do not call for now
+ if self.dataset_config.buckets:
+ # setup buckets
+ self.setup_buckets()
+ if self.is_caching_latents:
+ self.cache_latents_all_latents()
+ if self.is_caching_clip_vision_to_disk:
+ self.cache_clip_vision_to_disk()
+ else:
+ if self.dataset_config.poi is not None:
+ # handle cropping to a specific point of interest
+ # setup buckets every epoch
+ self.setup_buckets(quiet=True)
+ self.epoch_num += 1
+
+ def __len__(self):
+ if self.dataset_config.buckets:
+ return len(self.batch_indices)
+ return len(self.file_list)
+
+ def _get_single_item(self, index) -> 'FileItemDTO':
+ file_item = copy.deepcopy(self.file_list[index])
+ file_item.load_and_process_image(self.transform)
+ file_item.load_caption(self.caption_dict)
+ return file_item
+
+ def __getitem__(self, item):
+ if self.dataset_config.buckets:
+ # for buckets we collate ourselves for now
+ # todo allow a scheduler to dynamically make buckets
+ # we collate ourselves
+ if len(self.batch_indices) - 1 < item:
+ # tried everything to solve this. No way to reset length when redoing things. Pick another index
+ item = random.randint(0, len(self.batch_indices) - 1)
+ idx_list = self.batch_indices[item]
+ return [self._get_single_item(idx) for idx in idx_list]
+ else:
+ # Dataloader is batching
+ return self._get_single_item(item)
+
+
+def get_dataloader_from_datasets(
+ dataset_options,
+ batch_size=1,
+ sd: 'StableDiffusion' = None,
+) -> DataLoader:
+ if dataset_options is None or len(dataset_options) == 0:
+ return None
+
+ datasets = []
+ has_buckets = False
+ is_caching_latents = False
+
+ dataset_config_list = []
+ # preprocess them all
+ for dataset_option in dataset_options:
+ if isinstance(dataset_option, DatasetConfig):
+ dataset_config_list.append(dataset_option)
+ else:
+ # preprocess raw data
+ split_configs = preprocess_dataset_raw_config([dataset_option])
+ for x in split_configs:
+ dataset_config_list.append(DatasetConfig(**x))
+
+ for config in dataset_config_list:
+
+ if config.type == 'image':
+ dataset = AiToolkitDataset(config, batch_size=batch_size, sd=sd)
+ datasets.append(dataset)
+ if config.buckets:
+ has_buckets = True
+ if config.cache_latents or config.cache_latents_to_disk:
+ is_caching_latents = True
+ else:
+ raise ValueError(f"invalid dataset type: {config.type}")
+
+ concatenated_dataset = ConcatDataset(datasets)
+
+ # todo build scheduler that can get buckets from all datasets that match
+ # todo and evenly distribute reg images
+
+ def dto_collation(batch: List['FileItemDTO']):
+ # create DTO batch
+ batch = DataLoaderBatchDTO(
+ file_items=batch
+ )
+ return batch
+
+ # check if is caching latents
+
+ dataloader_kwargs = {}
+
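+ # on native Windows, load in the main process (num_workers=0); other platforms
+ # use the configured worker count and prefetch factor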
+ if is_native_windows():
+ dataloader_kwargs['num_workers'] = 0
+ else:
+ dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers
+ dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor
+
+ if has_buckets:
+ # make sure they all have buckets
+ for dataset in datasets:
+ assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
+
+ data_loader = DataLoader(
+ concatenated_dataset,
+ batch_size=None, # we batch in the datasets for now
+ drop_last=False,
+ shuffle=True,
+ collate_fn=dto_collation, # Use the custom collate function
+ **dataloader_kwargs
+ )
+ else:
+ data_loader = DataLoader(
+ concatenated_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ collate_fn=dto_collation,
+ **dataloader_kwargs
+ )
+ return data_loader
+
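+# Example usage (illustrative; `dataset_configs` and `sd` stand in for a list of dataset
+# config dicts / DatasetConfig objects and a loaded StableDiffusion instance):
+#   dataloader = get_dataloader_from_datasets(dataset_configs, batch_size=4, sd=sd)
+#   for batch in dataloader:  # each batch is a DataLoaderBatchDTO
+#       ...
+#       batch.cleanup()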
+
+def trigger_dataloader_setup_epoch(dataloader: DataLoader):
+ # hacky but needed because of different types of datasets and dataloaders
+ dataloader.len = None
+ if isinstance(dataloader.dataset, list):
+ for dataset in dataloader.dataset:
+ if hasattr(dataset, 'datasets'):
+ for sub_dataset in dataset.datasets:
+ if hasattr(sub_dataset, 'setup_epoch'):
+ sub_dataset.setup_epoch()
+ sub_dataset.len = None
+ elif hasattr(dataset, 'setup_epoch'):
+ dataset.setup_epoch()
+ dataset.len = None
+ elif hasattr(dataloader.dataset, 'setup_epoch'):
+ dataloader.dataset.setup_epoch()
+ dataloader.dataset.len = None
+ elif hasattr(dataloader.dataset, 'datasets'):
+ dataloader.dataset.len = None
+ for sub_dataset in dataloader.dataset.datasets:
+ if hasattr(sub_dataset, 'setup_epoch'):
+ sub_dataset.setup_epoch()
+ sub_dataset.len = None
+
+def get_dataloader_datasets(dataloader: DataLoader):
+ # hacky but needed because of different types of datasets and dataloaders
+ if isinstance(dataloader.dataset, list):
+ datasets = []
+ for dataset in dataloader.dataset:
+ if hasattr(dataset, 'datasets'):
+ for sub_dataset in dataset.datasets:
+ datasets.append(sub_dataset)
+ else:
+ datasets.append(dataset)
+ return datasets
+ elif hasattr(dataloader.dataset, 'datasets'):
+ return dataloader.dataset.datasets
+ else:
+ return [dataloader.dataset]
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..34239f4066a2dd885e27ff33e50e26391c17322d
--- /dev/null
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -0,0 +1,252 @@
+import os
+import weakref
+from _weakref import ReferenceType
+from typing import TYPE_CHECKING, List, Union
+import torch
+import random
+
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+
+from toolkit import image_utils
+from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
+ ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
+ UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin
+
+
+if TYPE_CHECKING:
+ from toolkit.config_modules import DatasetConfig
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+printed_messages = []
+
+
+def print_once(msg):
+ global printed_messages
+ if msg not in printed_messages:
+ print(msg)
+ printed_messages.append(msg)
+
+
+class FileItemDTO(
+ LatentCachingFileItemDTOMixin,
+ CaptionProcessingDTOMixin,
+ ImageProcessingDTOMixin,
+ ControlFileItemDTOMixin,
+ ClipImageFileItemDTOMixin,
+ MaskFileItemDTOMixin,
+ AugmentationFileItemDTOMixin,
+ UnconditionalFileItemDTOMixin,
+ PoiFileItemDTOMixin,
+ ArgBreakMixin,
+):
+ def __init__(self, *args, **kwargs):
+ self.path = kwargs.get('path', '')
+ self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+ size_database = kwargs.get('size_database', {})
+ dataset_root = kwargs.get('dataset_root', None)
+ if dataset_root is not None:
+ # remove dataset root from path
+ file_key = self.path.replace(dataset_root, '')
+ else:
+ file_key = os.path.basename(self.path)
+ if file_key in size_database:
+ w, h = size_database[file_key]
+ else:
+ # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
+ # process width and height
+ # try:
+ # w, h = image_utils.get_image_size(self.path)
+ # except image_utils.UnknownImageFormat:
+ # print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
+ # f'This process is faster for png, jpeg')
+ img = exif_transpose(Image.open(self.path))
+ w, h = img.size
+ size_database[file_key] = (w, h)
+ self.width: int = w
+ self.height: int = h
+ self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
+ super().__init__(*args, **kwargs)
+
+ # self.caption_path: str = kwargs.get('caption_path', None)
+ self.raw_caption: str = kwargs.get('raw_caption', None)
+ # we scale first, then crop
+ self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale))
+ self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale))
+ # crop values are from scaled size
+ self.crop_x: int = kwargs.get('crop_x', 0)
+ self.crop_y: int = kwargs.get('crop_y', 0)
+ self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
+ self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
+ self.flip_x: bool = kwargs.get('flip_x', False)
+ self.flip_y: bool = kwargs.get('flip_y', False)
+ self.augments: List[str] = self.dataset_config.augments
+ self.loss_multiplier: float = self.dataset_config.loss_multiplier
+
+ self.network_weight: float = self.dataset_config.network_weight
+ self.is_reg = self.dataset_config.is_reg
+ self.tensor: Union[torch.Tensor, None] = None
+
+ def cleanup(self):
+ self.tensor = None
+ self.cleanup_latent()
+ self.cleanup_control()
+ self.cleanup_clip_image()
+ self.cleanup_mask()
+ self.cleanup_unconditional()
+
+
+class DataLoaderBatchDTO:
+ def __init__(self, **kwargs):
+ try:
+ self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
+ is_latents_cached = self.file_items[0].is_latent_cached
+ self.tensor: Union[torch.Tensor, None] = None
+ self.latents: Union[torch.Tensor, None] = None
+ self.control_tensor: Union[torch.Tensor, None] = None
+ self.clip_image_tensor: Union[torch.Tensor, None] = None
+ self.mask_tensor: Union[torch.Tensor, None] = None
+ self.unaugmented_tensor: Union[torch.Tensor, None] = None
+ self.unconditional_tensor: Union[torch.Tensor, None] = None
+ self.unconditional_latents: Union[torch.Tensor, None] = None
+ self.clip_image_embeds: Union[List[dict], None] = None
+ self.clip_image_embeds_unconditional: Union[List[dict], None] = None
+ self.sigmas: Union[torch.Tensor, None] = None  # can be added elsewhere and passed along training code
+ self.extra_values: Union[torch.Tensor, None] = torch.tensor([x.extra_values for x in self.file_items]) if len(self.file_items[0].extra_values) > 0 else None
+ if not is_latents_cached:
+ # only return a tensor if latents are not cached
+ self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+ # if we have encoded latents, we concatenate them
+ self.latents: Union[torch.Tensor, None] = None
+ if is_latents_cached:
+ self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
+ self.control_tensor: Union[torch.Tensor, None] = None
+ # if self.file_items[0].control_tensor is not None:
+ # if any have a control tensor, we concatenate them
+ if any([x.control_tensor is not None for x in self.file_items]):
+ # find one to use as a base
+ base_control_tensor = None
+ for x in self.file_items:
+ if x.control_tensor is not None:
+ base_control_tensor = x.control_tensor
+ break
+ control_tensors = []
+ for x in self.file_items:
+ if x.control_tensor is None:
+ control_tensors.append(torch.zeros_like(base_control_tensor))
+ else:
+ control_tensors.append(x.control_tensor)
+ self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
+
+ self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
+
+ if any([x.clip_image_tensor is not None for x in self.file_items]):
+ # find one to use as a base
+ base_clip_image_tensor = None
+ for x in self.file_items:
+ if x.clip_image_tensor is not None:
+ base_clip_image_tensor = x.clip_image_tensor
+ break
+ clip_image_tensors = []
+ for x in self.file_items:
+ if x.clip_image_tensor is None:
+ clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
+ else:
+ clip_image_tensors.append(x.clip_image_tensor)
+ self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])
+
+ if any([x.mask_tensor is not None for x in self.file_items]):
+ # find one to use as a base
+ base_mask_tensor = None
+ for x in self.file_items:
+ if x.mask_tensor is not None:
+ base_mask_tensor = x.mask_tensor
+ break
+ mask_tensors = []
+ for x in self.file_items:
+ if x.mask_tensor is None:
+ mask_tensors.append(torch.zeros_like(base_mask_tensor))
+ else:
+ mask_tensors.append(x.mask_tensor)
+ self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
+
+ # add unaugmented tensors for ones with augments
+ if any([x.unaugmented_tensor is not None for x in self.file_items]):
+ # find one to use as a base
+ base_unaugmented_tensor = None
+ for x in self.file_items:
+ if x.unaugmented_tensor is not None:
+ base_unaugmented_tensor = x.unaugmented_tensor
+ break
+ unaugmented_tensor = []
+ for x in self.file_items:
+ if x.unaugmented_tensor is None:
+ unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
+ else:
+ unaugmented_tensor.append(x.unaugmented_tensor)
+ self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
+
+ # add unconditional tensors
+ if any([x.unconditional_tensor is not None for x in self.file_items]):
+ # find one to use as a base
+ base_unconditional_tensor = None
+ for x in self.file_items:
+ if x.unconditional_tensor is not None:
+ base_unconditional_tensor = x.unconditional_tensor
+ break
+ unconditional_tensor = []
+ for x in self.file_items:
+ if x.unconditional_tensor is None:
+ unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
+ else:
+ unconditional_tensor.append(x.unconditional_tensor)
+ self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
+
+ if any([x.clip_image_embeds is not None for x in self.file_items]):
+ self.clip_image_embeds = []
+ for x in self.file_items:
+ if x.clip_image_embeds is not None:
+ self.clip_image_embeds.append(x.clip_image_embeds)
+ else:
+ raise Exception("clip_image_embeds is None for some file items")
+
+ if any([x.clip_image_embeds_unconditional is not None for x in self.file_items]):
+ self.clip_image_embeds_unconditional = []
+ for x in self.file_items:
+ if x.clip_image_embeds_unconditional is not None:
+ self.clip_image_embeds_unconditional.append(x.clip_image_embeds_unconditional)
+ else:
+ raise Exception("clip_image_embeds_unconditional is None for some file items")
+
+ except Exception as e:
+ print(e)
+ raise e
+
+ def get_is_reg_list(self):
+ return [x.is_reg for x in self.file_items]
+
+ def get_network_weight_list(self):
+ return [x.network_weight for x in self.file_items]
+
+ def get_caption_list(
+ self,
+ trigger=None,
+ to_replace_list=None,
+ add_if_not_present=True
+ ):
+ return [x.caption for x in self.file_items]
+
+ def get_caption_short_list(
+ self,
+ trigger=None,
+ to_replace_list=None,
+ add_if_not_present=True
+ ):
+ return [x.caption_short for x in self.file_items]
+
+ def cleanup(self):
+ del self.latents
+ del self.tensor
+ del self.control_tensor
+ for file_item in self.file_items:
+ file_item.cleanup()
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bba44316952858c0cbf870cb70ec85db4c8fb99
--- /dev/null
+++ b/toolkit/dataloader_mixins.py
@@ -0,0 +1,1630 @@
+import base64
+import glob
+import hashlib
+import json
+import math
+import os
+import random
+from collections import OrderedDict
+from typing import TYPE_CHECKING, List, Dict, Union
+
+import cv2
+import numpy as np
+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor
+
+from toolkit.basic import flush, value_map
+from toolkit.buckets import get_bucket_for_image_size, get_resolution
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.pixtral_vision import PixtralVisionImagePreprocessorCompatible
+from toolkit.prompt_utils import inject_trigger_into_prompt
+from torchvision import transforms
+from PIL import Image, ImageFilter, ImageOps
+from PIL.ImageOps import exif_transpose
+import albumentations as A
+
+from toolkit.train_tools import get_torch_dtype
+
+if TYPE_CHECKING:
+ from toolkit.data_loader import AiToolkitDataset
+ from toolkit.data_transfer_object.data_loader import FileItemDTO
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+# def get_associated_caption_from_img_path(img_path):
+# https://demo.albumentations.ai/
+class Augments:
+ def __init__(self, **kwargs):
+ self.method_name = kwargs.get('method', None)
+ self.params = kwargs.get('params', {})
+
+ # convert kwargs enums for cv2
+ for key, value in self.params.items():
+ if isinstance(value, str):
+ # split the string
+ split_string = value.split('.')
+ if len(split_string) == 2 and split_string[0] == 'cv2':
+ if hasattr(cv2, split_string[1]):
+ self.params[key] = getattr(cv2, split_string[1].upper())
+ else:
+ raise ValueError(f"invalid cv2 enum: {split_string[1]}")
+
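+# Example (illustrative) of an augmentation entry the Augments wrapper accepts; the
+# "cv2.*" string is resolved to the matching cv2 constant above:
+#   {"method": "Rotate", "params": {"limit": 10, "border_mode": "cv2.BORDER_CONSTANT"}}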
+
+transforms_dict = {
+ 'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
+ 'RandomEqualize': transforms.RandomEqualize(p=0.2),
+}
+
+caption_ext_list = ['txt', 'json', 'caption']
+
+
+def standardize_images(images):
+ """
+ Standardize the given batch of images using the specified mean and std.
+ Expects values of 0 - 1
+
+ Args:
+ images (torch.Tensor): A batch of images in the shape of (N, C, H, W),
+ where N is the number of images, C is the number of channels,
+ H is the height, and W is the width.
+
+ Returns:
+ torch.Tensor: Standardized images.
+ """
+ mean = [0.48145466, 0.4578275, 0.40821073]
+ std = [0.26862954, 0.26130258, 0.27577711]
+
+ # Define the normalization transform
+ normalize = transforms.Normalize(mean=mean, std=std)
+
+ # Apply normalization to each image in the batch
+ standardized_images = torch.stack([normalize(img) for img in images])
+
+ return standardized_images
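+
+# Example usage (illustrative):
+#   images = torch.rand(4, 3, 224, 224)   # batch of 0-1 images
+#   images = standardize_images(images)   # per-channel (x - mean) / std with CLIP statistics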
+
+def clean_caption(caption):
+ # remove any newlines
+ caption = caption.replace('\n', ', ')
+ # remove new lines for all operating systems
+ caption = caption.replace('\r', ', ')
+ caption_split = caption.split(',')
+ # remove empty strings
+ caption_split = [p.strip() for p in caption_split if p.strip()]
+ # join back together
+ caption = ', '.join(caption_split)
+ return caption
+
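+# Example (illustrative): clean_caption("a cat,, sitting\n on a mat") returns "a cat, sitting, on a mat"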
+
+class CaptionMixin:
+ def get_caption_item(self: 'AiToolkitDataset', index):
+ if not hasattr(self, 'caption_type'):
+ raise Exception('caption_type not found on class instance')
+ if not hasattr(self, 'file_list'):
+ raise Exception('file_list not found on class instance')
+ img_path_or_tuple = self.file_list[index]
+ if isinstance(img_path_or_tuple, tuple):
+ img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path
+ # check if either has a prompt file
+ path_no_ext = os.path.splitext(img_path)[0]
+ prompt_path = None
+ for ext in caption_ext_list:
+ prompt_path = path_no_ext + '.' + ext
+ if os.path.exists(prompt_path):
+ break
+ else:
+ img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
+ # see if prompt file exists
+ path_no_ext = os.path.splitext(img_path)[0]
+ prompt_path = None
+ for ext in caption_ext_list:
+ prompt_path = path_no_ext + '.' + ext
+ if os.path.exists(prompt_path):
+ break
+
+ # allow folders to have a default prompt
+ default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt')
+
+ if os.path.exists(prompt_path):
+ with open(prompt_path, 'r', encoding='utf-8') as f:
+ prompt = f.read()
+ # check if is json
+ if prompt_path.endswith('.json'):
+ prompt = json.loads(prompt)
+ if 'caption' in prompt:
+ prompt = prompt['caption']
+
+ prompt = clean_caption(prompt)
+ elif os.path.exists(default_prompt_path):
+ with open(default_prompt_path, 'r', encoding='utf-8') as f:
+ prompt = f.read()
+ prompt = clean_caption(prompt)
+ else:
+ prompt = ''
+ # get default_prompt if it exists on the class instance
+ if hasattr(self, 'default_prompt'):
+ prompt = self.default_prompt
+ if hasattr(self, 'default_caption'):
+ prompt = self.default_caption
+
+ # handle replacements
+ replacement_list = []
+ if hasattr(self, 'dataset_config') and isinstance(getattr(self.dataset_config, 'replacements', None), list):
+ replacement_list = self.dataset_config.replacements
+ for replacement in replacement_list:
+ from_string, to_string = replacement.split('|')
+ prompt = prompt.replace(from_string, to_string)
+
+ return prompt
+
+
+if TYPE_CHECKING:
+ from toolkit.config_modules import DatasetConfig
+ from toolkit.data_transfer_object.data_loader import FileItemDTO
+
+
+class Bucket:
+ def __init__(self, width: int, height: int):
+ self.width = width
+ self.height = height
+ self.file_list_idx: List[int] = []
+
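+# Note: buckets are keyed by their post-crop size string (f"{width}x{height}"); batches
+# are drawn from a single bucket at a time, so every item in a batch shares a resolution.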
+
+class BucketsMixin:
+ def __init__(self):
+ self.buckets: Dict[str, Bucket] = {}
+ self.batch_indices: List[List[int]] = []
+
+ def build_batch_indices(self: 'AiToolkitDataset'):
+ self.batch_indices = []
+ for key, bucket in self.buckets.items():
+ for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
+ end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
+ batch = bucket.file_list_idx[start_idx:end_idx]
+ self.batch_indices.append(batch)
+
+ def shuffle_buckets(self: 'AiToolkitDataset'):
+ for key, bucket in self.buckets.items():
+ random.shuffle(bucket.file_list_idx)
+
+ def setup_buckets(self: 'AiToolkitDataset', quiet=False):
+ if not hasattr(self, 'file_list'):
+ raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
+ if not hasattr(self, 'dataset_config'):
+ raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
+
+ if self.epoch_num > 0 and self.dataset_config.poi is None:
+ # no need to rebuild buckets for now
+ # todo handle random cropping for buckets
+ return
+ self.buckets = {} # clear it
+
+ config: 'DatasetConfig' = self.dataset_config
+ resolution = config.resolution
+ bucket_tolerance = config.bucket_tolerance
+ file_list: List['FileItemDTO'] = self.file_list
+
+ # for file_item in enumerate(file_list):
+ for idx, file_item in enumerate(file_list):
+ file_item: 'FileItemDTO' = file_item
+ width = int(file_item.width * file_item.dataset_config.scale)
+ height = int(file_item.height * file_item.dataset_config.scale)
+
+ did_process_poi = False
+ if file_item.has_point_of_interest:
+ # Attempt to process the poi if we can. It won't process if the image is smaller than the resolution
+ did_process_poi = file_item.setup_poi_bucket()
+ if self.dataset_config.square_crop:
+ # we scale first so smallest size matches resolution
+ scale_factor_x = resolution / width
+ scale_factor_y = resolution / height
+ scale_factor = max(scale_factor_x, scale_factor_y)
+ file_item.scale_to_width = math.ceil(width * scale_factor)
+ file_item.scale_to_height = math.ceil(height * scale_factor)
+ file_item.crop_width = resolution
+ file_item.crop_height = resolution
+ if width > height:
+ file_item.crop_x = int(file_item.scale_to_width / 2 - resolution / 2)
+ file_item.crop_y = 0
+ else:
+ file_item.crop_x = 0
+ file_item.crop_y = int(file_item.scale_to_height / 2 - resolution / 2)
+ elif not did_process_poi:
+ bucket_resolution = get_bucket_for_image_size(
+ width, height,
+ resolution=resolution,
+ divisibility=bucket_tolerance
+ )
+
+ # Calculate scale factors for width and height
+ width_scale_factor = bucket_resolution["width"] / width
+ height_scale_factor = bucket_resolution["height"] / height
+
+ # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
+ max_scale_factor = max(width_scale_factor, height_scale_factor)
+
+ # round up
+ file_item.scale_to_width = int(math.ceil(width * max_scale_factor))
+ file_item.scale_to_height = int(math.ceil(height * max_scale_factor))
+
+ file_item.crop_height = bucket_resolution["height"]
+ file_item.crop_width = bucket_resolution["width"]
+
+ new_width = bucket_resolution["width"]
+ new_height = bucket_resolution["height"]
+
+ if self.dataset_config.random_crop:
+ # random crop
+ crop_x = random.randint(0, file_item.scale_to_width - new_width)
+ crop_y = random.randint(0, file_item.scale_to_height - new_height)
+ file_item.crop_x = crop_x
+ file_item.crop_y = crop_y
+ else:
+ # do central crop
+ file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
+ file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
+
+ if file_item.crop_y < 0 or file_item.crop_x < 0:
+ print(f'Warning: negative crop offset computed for {file_item.path}')
+
+ # check if bucket exists, if not, create it
+ bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
+ if bucket_key not in self.buckets:
+ self.buckets[bucket_key] = Bucket(file_item.crop_width, file_item.crop_height)
+ self.buckets[bucket_key].file_list_idx.append(idx)
+
+ # shuffle the buckets, build batch indices, then print a summary
+ self.shuffle_buckets()
+ self.build_batch_indices()
+ if not quiet:
+ print(f'Bucket sizes for {self.dataset_path}:')
+ for key, bucket in self.buckets.items():
+ print(f'{key}: {len(bucket.file_list_idx)} files')
+ print(f'{len(self.buckets)} buckets made')
+
+
+class CaptionProcessingDTOMixin:
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self.raw_caption: str = None
+ self.raw_caption_short: str = None
+ self.caption: str = None
+ self.caption_short: str = None
+
+ dataset_config: DatasetConfig = kwargs.get('dataset_config', None)
+ self.extra_values: List[float] = dataset_config.extra_values
+
+ # todo allow for loading from sd-scripts style dict
+ def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
+ if self.raw_caption is not None:
+ # we already loaded it
+ pass
+ elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
+ self.raw_caption = caption_dict[self.path]["caption"]
+ if 'caption_short' in caption_dict[self.path]:
+ self.raw_caption_short = caption_dict[self.path]["caption_short"]
+ else:
+ # see if prompt file exists
+ path_no_ext = os.path.splitext(self.path)[0]
+ prompt_ext = self.dataset_config.caption_ext
+ prompt_path = f"{path_no_ext}.{prompt_ext}"
+ short_caption = None
+
+ if os.path.exists(prompt_path):
+ with open(prompt_path, 'r', encoding='utf-8') as f:
+ prompt = f.read()
+ short_caption = None
+ if prompt_path.endswith('.json'):
+ # replace any line endings (\r\n, \n, \r) with spaces
+ prompt = prompt.replace('\r\n', ' ')
+ prompt = prompt.replace('\n', ' ')
+ prompt = prompt.replace('\r', ' ')
+
+ prompt_json = json.loads(prompt)
+ if 'caption' in prompt_json:
+ prompt = prompt_json['caption']
+ if 'caption_short' in prompt_json:
+ short_caption = prompt_json['caption_short']
+
+ if 'extra_values' in prompt_json:
+ self.extra_values = prompt_json['extra_values']
+
+ prompt = clean_caption(prompt)
+ if short_caption is not None:
+ short_caption = clean_caption(short_caption)
+ else:
+ prompt = ''
+ if self.dataset_config.default_caption is not None:
+ prompt = self.dataset_config.default_caption
+
+ if short_caption is None:
+ short_caption = self.dataset_config.default_caption
+ self.raw_caption = prompt
+ self.raw_caption_short = short_caption
+
+ self.caption = self.get_caption()
+ if self.raw_caption_short is not None:
+ self.caption_short = self.get_caption(short_caption=True)
+
+ def get_caption(
+ self: 'FileItemDTO',
+ trigger=None,
+ to_replace_list=None,
+ add_if_not_present=False,
+ short_caption=False
+ ):
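+ # processing order: caption dropout, token dropout (keeping the first keep_tokens), optional shuffle, random triggers, then a final shuffle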
+ if short_caption:
+ raw_caption = self.raw_caption_short
+ else:
+ raw_caption = self.raw_caption
+ if raw_caption is None:
+ raw_caption = ''
+ # handle dropout
+ if self.dataset_config.caption_dropout_rate > 0 and not short_caption:
+ # get a random float from 0 to 1
+ rand = random.random()
+ if rand < self.dataset_config.caption_dropout_rate:
+ # drop the caption
+ return ''
+
+ # get tokens
+ token_list = raw_caption.split(',')
+ # trim whitespace
+ token_list = [x.strip() for x in token_list]
+ # remove empty strings
+ token_list = [x for x in token_list if x]
+
+ # handle token dropout
+ if self.dataset_config.token_dropout_rate > 0 and not short_caption:
+ new_token_list = []
+ keep_tokens: int = self.dataset_config.keep_tokens
+ for idx, token in enumerate(token_list):
+ if idx < keep_tokens:
+ new_token_list.append(token)
+ elif self.dataset_config.token_dropout_rate >= 1.0:
+ # drop the token
+ pass
+ else:
+ # get a random float from 0 to 1
+ rand = random.random()
+ if rand > self.dataset_config.token_dropout_rate:
+ # keep the token
+ new_token_list.append(token)
+ token_list = new_token_list
+
+ if self.dataset_config.shuffle_tokens:
+ random.shuffle(token_list)
+
+ # join back together
+ caption = ', '.join(token_list)
+ # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
+
+ if self.dataset_config.random_triggers:
+ num_triggers = self.dataset_config.random_triggers_max
+ if num_triggers > 1:
+ num_triggers = random.randint(0, num_triggers)
+
+ if num_triggers > 0:
+ triggers = random.sample(self.dataset_config.random_triggers, num_triggers)
+ caption = caption + ', ' + ', '.join(triggers)
+ # add random triggers
+ # for i in range(num_triggers):
+ # # fastest method
+ # trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))]
+ # caption = caption + ', ' + trigger
+
+ if self.dataset_config.shuffle_tokens:
+ # shuffle again
+ token_list = caption.split(',')
+ # trim whitespace
+ token_list = [x.strip() for x in token_list]
+ # remove empty strings
+ token_list = [x for x in token_list if x]
+ random.shuffle(token_list)
+ caption = ', '.join(token_list)
+
+ return caption
+
+
+class ImageProcessingDTOMixin:
+ def load_and_process_image(
+ self: 'FileItemDTO',
+ transform: Union[None, transforms.Compose],
+ only_load_latents=False
+ ):
+ # if we are caching latents, just do that
+ if self.is_latent_cached:
+ self.get_latent()
+ if self.has_control_image:
+ self.load_control_image()
+ if self.has_clip_image:
+ self.load_clip_image()
+ if self.has_mask_image:
+ self.load_mask_image()
+ if self.has_unconditional:
+ self.load_unconditional_image()
+ return
+ try:
+ img = Image.open(self.path)
+ img = exif_transpose(img)
+ except Exception as e:
+ print(f"Error: {e}")
+ print(f"Error loading image: {self.path}")
+
+ if self.use_alpha_as_mask:
+ # we do this to make sure it does not replace the alpha with another color
+ # we want the image just without the alpha channel
+ np_img = np.array(img)
+ # strip off alpha
+ np_img = np_img[:, :, :3]
+ img = Image.fromarray(np_img)
+
+ img = img.convert('RGB')
+ w, h = img.size
+ if w > h and self.scale_to_width < self.scale_to_height:
+ # these should match; warn but continue
+ print(
+ f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+ elif h > w and self.scale_to_height < self.scale_to_width:
+ # these should match; warn but continue
+ print(
+ f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+
+ if self.flip_x:
+ # do a flip
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ if self.flip_y:
+ # do a flip
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+ if self.dataset_config.buckets:
+ # scale and crop based on file item
+ img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+ # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
+ if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
+ # todo look into this. This still happens sometimes
+ print(f'Warning: crop box exceeds resized image size for {self.path}')
+ img = img.crop((
+ self.crop_x,
+ self.crop_y,
+ self.crop_x + self.crop_width,
+ self.crop_y + self.crop_height
+ ))
+
+ # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+ else:
+ # Downscale the source image first
+ # TODO this is not right
+ img = img.resize(
+ (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
+ Image.BICUBIC)
+ min_img_size = min(img.size)
+ if self.dataset_config.random_crop:
+ if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution:
+ if min_img_size < self.dataset_config.resolution:
+ print(
+ f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}")
+ scale_size = self.dataset_config.resolution
+ else:
+ scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
+ scaler = scale_size / min_img_size
+ scale_width = int((img.width + 5) * scaler)
+ scale_height = int((img.height + 5) * scaler)
+ img = img.resize((scale_width, scale_height), Image.BICUBIC)
+ img = transforms.RandomCrop(self.dataset_config.resolution)(img)
+ else:
+ img = transforms.CenterCrop(min_img_size)(img)
+ img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)
+
+ if self.augments is not None and len(self.augments) > 0:
+ # do augmentations
+ for augment in self.augments:
+ if augment in transforms_dict:
+ img = transforms_dict[augment](img)
+
+ if self.has_augmentations:
+ # augmentations handles transforms
+ img = self.augment_image(img, transform=transform)
+ elif transform:
+ img = transform(img)
+
+ self.tensor = img
+ if not only_load_latents:
+ if self.has_control_image:
+ self.load_control_image()
+ if self.has_clip_image:
+ self.load_clip_image()
+ if self.has_mask_image:
+ self.load_mask_image()
+ if self.has_unconditional:
+ self.load_unconditional_image()
+
+
+class ControlFileItemDTOMixin:
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self.has_control_image = False
+ self.control_path: Union[str, None] = None
+ self.control_tensor: Union[torch.Tensor, None] = None
+ dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+ self.full_size_control_images = False
+ if dataset_config.control_path is not None:
+ # find the control image path
+ control_path = dataset_config.control_path
+ self.full_size_control_images = dataset_config.full_size_control_images
+ # we are using control images
+ img_path = kwargs.get('path', None)
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+ file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+ for ext in img_ext_list:
+ if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)):
+ self.control_path = os.path.join(control_path, file_name_no_ext + ext)
+ self.has_control_image = True
+ break
+
+ def load_control_image(self: 'FileItemDTO'):
+ try:
+ img = Image.open(self.control_path).convert('RGB')
+ img = exif_transpose(img)
+ except Exception as e:
+ print(f"Error: {e}")
+ print(f"Error loading image: {self.control_path}")
+
+ if self.full_size_control_images:
+ # we just scale them to 512x512:
+ w, h = img.size
+ img = img.resize((512, 512), Image.BICUBIC)
+
+ else:
+ w, h = img.size
+ if w > h and self.scale_to_width < self.scale_to_height:
+ # throw error, they should match
+ raise ValueError(
+ f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+ elif h > w and self.scale_to_height < self.scale_to_width:
+ # throw error, they should match
+ raise ValueError(
+ f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+
+ if self.flip_x:
+ # do a flip
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ if self.flip_y:
+ # do a flip
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+ if self.dataset_config.buckets:
+ # scale and crop based on file item
+ img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+ # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+ # crop
+ img = img.crop((
+ self.crop_x,
+ self.crop_y,
+ self.crop_x + self.crop_width,
+ self.crop_y + self.crop_height
+ ))
+ else:
+ raise Exception("Control images not supported for non-bucket datasets")
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ if self.aug_replay_spatial_transforms:
+ self.control_tensor = self.augment_spatial_control(img, transform=transform)
+ else:
+ self.control_tensor = transform(img)
+
+ def cleanup_control(self: 'FileItemDTO'):
+ self.control_tensor = None
+
+
+class ClipImageFileItemDTOMixin:
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self.has_clip_image = False
+ self.clip_image_path: Union[str, None] = None
+ self.clip_image_tensor: Union[torch.Tensor, None] = None
+ self.clip_image_embeds: Union[dict, None] = None
+ self.clip_image_embeds_unconditional: Union[dict, None] = None
+ self.has_clip_augmentations = False
+ self.clip_image_aug_transform: Union[None, A.Compose] = None
+ self.clip_image_processor: Union[None, CLIPImageProcessor] = None
+ self.clip_image_encoder_path: Union[str, None] = None
+ self.is_caching_clip_vision_to_disk = False
+ self.is_vision_clip_cached = False
+ self.clip_vision_is_quad = False
+ self.clip_vision_load_device = 'cpu'
+ self.clip_vision_unconditional_paths: Union[List[str], None] = None
+ self._clip_vision_embeddings_path: Union[str, None] = None
+ dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+ if dataset_config.clip_image_path is not None or dataset_config.clip_image_from_same_folder:
+ # copy the clip image processor so the dataloader can do it
+ sd = kwargs.get('sd', None)
+ if hasattr(sd.adapter, 'clip_image_processor'):
+ self.clip_image_processor = sd.adapter.clip_image_processor
+ if dataset_config.clip_image_path is not None:
+ # find the clip image path
+ clip_image_path = dataset_config.clip_image_path
+ # we are using clip reference images
+ img_path = kwargs.get('path', None)
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+ file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+ for ext in img_ext_list:
+ if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)):
+ self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext)
+ self.has_clip_image = True
+ break
+ self.build_clip_imag_augmentation_transform()
+
+ if dataset_config.clip_image_from_same_folder:
+ # assume we have one. We will pull it on load.
+ self.has_clip_image = True
+ self.build_clip_imag_augmentation_transform()
+
+ def build_clip_imag_augmentation_transform(self: 'FileItemDTO'):
+ if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0:
+ self.has_clip_augmentations = True
+ augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations]
+
+ if self.dataset_config.clip_image_shuffle_augmentations:
+ random.shuffle(augmentations)
+
+ augmentation_list = []
+ for aug in augmentations:
+ # make sure method name is valid
+ assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
+ # get the method
+ method = getattr(A, aug.method_name)
+ # add the method to the list
+ augmentation_list.append(method(**aug.params))
+
+ self.clip_image_aug_transform = A.Compose(augmentation_list)
+
+ def augment_clip_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
+ if self.dataset_config.clip_image_shuffle_augmentations:
+ self.build_clip_imag_augmentation_transform()
+
+ open_cv_image = np.array(img)
+ # Convert RGB to BGR
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
+
+ if self.clip_vision_is_quad:
+ # image is a 2x2 grid: split it, run augmentations on each tile, and recombine
+ # split
+ img1, img2 = np.hsplit(open_cv_image, 2)
+ img1_1, img1_2 = np.vsplit(img1, 2)
+ img2_1, img2_2 = np.vsplit(img2, 2)
+ # apply augmentations
+ img1_1 = self.clip_image_aug_transform(image=img1_1)["image"]
+ img1_2 = self.clip_image_aug_transform(image=img1_2)["image"]
+ img2_1 = self.clip_image_aug_transform(image=img2_1)["image"]
+ img2_2 = self.clip_image_aug_transform(image=img2_2)["image"]
+ # recombine
+ augmented = np.vstack((np.hstack((img1_1, img1_2)), np.hstack((img2_1, img2_2))))
+
+ else:
+ # apply augmentations
+ augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
+
+ # convert back to RGB tensor
+ augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
+
+ # convert to PIL image
+ augmented = Image.fromarray(augmented)
+
+ augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
+
+ return augmented_tensor
+
+ def get_clip_vision_info_dict(self: 'FileItemDTO'):
+ item = OrderedDict([
+ ("image_encoder_path", self.clip_image_encoder_path),
+ ("filename", os.path.basename(self.clip_image_path)),
+ ("is_quad", self.clip_vision_is_quad)
+ ])
+ # when adding new fields, add them after these so existing cached embeddings keep the same hash
+ if self.flip_x:
+ item["flip_x"] = True
+ if self.flip_y:
+ item["flip_y"] = True
+ return item
+ def get_clip_vision_embeddings_path(self: 'FileItemDTO', recalculate=False):
+ if self._clip_vision_embeddings_path is not None and not recalculate:
+ return self._clip_vision_embeddings_path
+ else:
+ # we store clip vision embeddings in a folder next to the image called _clip_vision_cache
+ img_dir = os.path.dirname(self.clip_image_path)
+ latent_dir = os.path.join(img_dir, '_clip_vision_cache')
+ hash_dict = self.get_clip_vision_info_dict()
+ filename_no_ext = os.path.splitext(os.path.basename(self.clip_image_path))[0]
+ # get base64 hash of md5 checksum of hash_dict
+ hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+ hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+ hash_str = hash_str.replace('=', '')
+ self._clip_vision_embeddings_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
+
+ return self._clip_vision_embeddings_path
+
+ def get_new_clip_image_path(self: 'FileItemDTO'):
+ if self.dataset_config.clip_image_from_same_folder:
+ # randomly grab an image path from the same folder
+ pool_folder = os.path.dirname(self.path)
+ # find all images in the folder
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+ img_files = []
+ for ext in img_ext_list:
+ img_files += glob.glob(os.path.join(pool_folder, f'*{ext}'))
+ # remove the current image if len is greater than 1
+ if len(img_files) > 1:
+ img_files.remove(self.path)
+ # randomly grab one
+ return random.choice(img_files)
+ else:
+ return self.clip_image_path
+
+ def load_clip_image(self: 'FileItemDTO'):
+ is_dynamic_size_and_aspect = isinstance(self.clip_image_processor, PixtralVisionImagePreprocessorCompatible) or \
+ isinstance(self.clip_image_processor, SiglipImageProcessor)
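+ # these processors are treated as handling resizing/aspect themselves, so the manual square resize below is skipped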
+ if self.is_vision_clip_cached:
+ self.clip_image_embeds = load_file(self.get_clip_vision_embeddings_path())
+
+ # get a random unconditional image
+ if self.clip_vision_unconditional_paths is not None:
+ unconditional_path = random.choice(self.clip_vision_unconditional_paths)
+ self.clip_image_embeds_unconditional = load_file(unconditional_path)
+
+ return
+ clip_image_path = self.get_new_clip_image_path()
+ try:
+ img = Image.open(clip_image_path).convert('RGB')
+ img = exif_transpose(img)
+ except Exception as e:
+ # make a random noise image
+ img = Image.new('RGB', (self.dataset_config.resolution, self.dataset_config.resolution))
+ print(f"Error: {e}")
+ print(f"Error loading image: {clip_image_path}")
+
+ img = img.convert('RGB')
+
+ if self.flip_x:
+ # do a flip
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ if self.flip_y:
+ # do a flip
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+ if is_dynamic_size_and_aspect:
+ pass # let the image processor handle it
+ elif img.width != img.height:
+ min_size = min(img.width, img.height)
+ if self.dataset_config.square_crop:
+ # center crop to a square
+ img = transforms.CenterCrop(min_size)(img)
+ else:
+ # image must be square. If it is not, we will resize/squish it so it is, that way we don't crop out data
+ # resize to the smallest dimension
+ img = img.resize((min_size, min_size), Image.BICUBIC)
+
+ if self.has_clip_augmentations:
+ self.clip_image_tensor = self.augment_clip_image(img, transform=None)
+ else:
+ self.clip_image_tensor = transforms.ToTensor()(img)
+
+ # random crop
+ # if self.dataset_config.clip_image_random_crop:
+ # # crop up to 20% on all sides. Keep is square
+ # crop_percent = random.randint(0, 20) / 100
+ # crop_width = int(self.clip_image_tensor.shape[2] * crop_percent)
+ # crop_height = int(self.clip_image_tensor.shape[1] * crop_percent)
+ # crop_left = random.randint(0, crop_width)
+ # crop_top = random.randint(0, crop_height)
+ # crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left
+ # crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top
+ # if len(self.clip_image_tensor.shape) == 3:
+ # self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right]
+ # elif len(self.clip_image_tensor.shape) == 4:
+ # self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right]
+
+ if self.clip_image_processor is not None:
+ # run it
+ tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)
+ clip_out = self.clip_image_processor(
+ images=tensors_0_1,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ ).pixel_values
+ self.clip_image_tensor = clip_out.squeeze(0).clone().detach()
+
+ def cleanup_clip_image(self: 'FileItemDTO'):
+ self.clip_image_tensor = None
+ self.clip_image_embeds = None
+
+
+
+
+class AugmentationFileItemDTOMixin:
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self.has_augmentations = False
+ self.unaugmented_tensor: Union[torch.Tensor, None] = None
+ # self.augmentations: Union[None, List[Augments]] = None
+ self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+ self.aug_transform: Union[None, A.Compose] = None
+ self.aug_replay_spatial_transforms = None
+ self.build_augmentation_transform()
+
+ def build_augmentation_transform(self: 'FileItemDTO'):
+ if self.dataset_config.augmentations is not None and len(self.dataset_config.augmentations) > 0:
+ self.has_augmentations = True
+ augmentations = [Augments(**aug) for aug in self.dataset_config.augmentations]
+
+ if self.dataset_config.shuffle_augmentations:
+ random.shuffle(augmentations)
+
+ augmentation_list = []
+ for aug in augmentations:
+ # make sure method name is valid
+ assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
+ # get the method
+ method = getattr(A, aug.method_name)
+ # add the method to the list
+ augmentation_list.append(method(**aug.params))
+
+ # add additional targets so we can augment the control image
+ self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'})
+
+ def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
+
+ # rebuild each time if shuffle
+ if self.dataset_config.shuffle_augmentations:
+ self.build_augmentation_transform()
+
+ # save the original tensor
+ self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
+
+ open_cv_image = np.array(img)
+ # Convert RGB to BGR
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
+
+ # apply augmentations
+ transformed = self.aug_transform(image=open_cv_image)
+ augmented = transformed["image"]
+
+ # save just the spatial transforms for controls and masks
+ augmented_params = transformed["replay"]
+ spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop',
+ 'ElasticTransform', 'GridDistortion', 'OpticalDistortion']
+ # only store the spatial transforms
+ augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms]
+
+ if self.dataset_config.replay_transforms:
+ self.aug_replay_spatial_transforms = augmented_params
+
+ # convert back to RGB tensor
+ augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
+
+ # convert to PIL image
+ augmented = Image.fromarray(augmented)
+
+ augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
+
+ return augmented_tensor
+
+ # augment control images spatially consistent with transforms done to the main image
+ def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ):
+ if self.aug_replay_spatial_transforms is None:
+ # no transforms
+ return transform(img)
+
+ # save colorspace to convert back to
+ colorspace = img.mode
+
+ # convert to rgb
+ img = img.convert('RGB')
+
+ open_cv_image = np.array(img)
+ # Convert RGB to BGR
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
+
+ # Replay transforms
+ transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image)
+ augmented = transformed["image"]
+
+ # convert back to RGB tensor
+ augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
+
+ # convert to PIL image
+ augmented = Image.fromarray(augmented)
+
+ # convert back to original colorspace
+ augmented = augmented.convert(colorspace)
+
+ augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
+ return augmented_tensor
+
+ def cleanup_control(self: 'FileItemDTO'):
+ self.unaugmented_tensor = None
+
+
+class MaskFileItemDTOMixin:
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self.has_mask_image = False
+ self.mask_path: Union[str, None] = None
+ self.mask_tensor: Union[torch.Tensor, None] = None
+ self.use_alpha_as_mask: bool = False
+ dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+ self.mask_min_value = dataset_config.mask_min_value
+ if dataset_config.alpha_mask:
+ self.use_alpha_as_mask = True
+ self.mask_path = kwargs.get('path', None)
+ self.has_mask_image = True
+ elif dataset_config.mask_path is not None:
+ # find the mask image path
+ mask_path = dataset_config.mask_path if dataset_config.mask_path is not None else dataset_config.alpha_mask
+ # we are using mask images
+ img_path = kwargs.get('path', None)
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+ file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+ for ext in img_ext_list:
+ if os.path.exists(os.path.join(mask_path, file_name_no_ext + ext)):
+ self.mask_path = os.path.join(mask_path, file_name_no_ext + ext)
+ self.has_mask_image = True
+ break
+
+ def load_mask_image(self: 'FileItemDTO'):
+ try:
+ img = Image.open(self.mask_path)
+ img = exif_transpose(img)
+ except Exception as e:
+ print(f"Error: {e}")
+ print(f"Error loading image: {self.mask_path}")
+
+ if self.use_alpha_as_mask:
+ # the pipeline expects an RGB image, so copy the alpha channel into all three channels
+ np_img = np.array(img)
+ np_img[:, :, :3] = np_img[:, :, 3:]
+
+ np_img = np_img[:, :, :3]
+ img = Image.fromarray(np_img)
+
+ img = img.convert('RGB')
+ if self.dataset_config.invert_mask:
+ img = ImageOps.invert(img)
+ w, h = img.size
+ fix_size = False
+ if w > h and self.scale_to_width < self.scale_to_height:
+ # these should match; warn, then fix the sizes below
+ print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+ fix_size = True
+ elif h > w and self.scale_to_height < self.scale_to_width:
+ # these should match; warn, then fix the sizes below
+ print(f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+ fix_size = True
+
+ if fix_size:
+ # swap all the sizes
+ self.scale_to_width, self.scale_to_height = self.scale_to_height, self.scale_to_width
+ self.crop_width, self.crop_height = self.crop_height, self.crop_width
+ self.crop_x, self.crop_y = self.crop_y, self.crop_x
+
+ if self.flip_x:
+ # do a flip
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ if self.flip_y:
+ # do a flip
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+ # randomly apply a blur up to 0.5% of the size of the min (width, height)
+ min_size = min(img.width, img.height)
+ blur_radius = int(min_size * random.random() * 0.005)
+ img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+
+ # make grayscale
+ img = img.convert('L')
+
+ if self.dataset_config.buckets:
+ # scale and crop based on file item
+ img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+ # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+ # crop
+ img = img.crop((
+ self.crop_x,
+ self.crop_y,
+ self.crop_x + self.crop_width,
+ self.crop_y + self.crop_height
+ ))
+ else:
+ raise Exception("Mask images not supported for non-bucket datasets")
+
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ if self.aug_replay_spatial_transforms:
+ self.mask_tensor = self.augment_spatial_control(img, transform=transform)
+ else:
+ self.mask_tensor = transform(img)
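+ # remap mask values from [0, 1] into [mask_min_value, 1] so no pixel falls below the configured minimum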
+ self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0)
+ # convert to grayscale
+
+ def cleanup_mask(self: 'FileItemDTO'):
+ self.mask_tensor = None
+
+
+class UnconditionalFileItemDTOMixin:
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self.has_unconditional = False
+ self.unconditional_path: Union[str, None] = None
+ self.unconditional_tensor: Union[torch.Tensor, None] = None
+ self.unconditional_latent: Union[torch.Tensor, None] = None
+ self.unconditional_transforms = self.dataloader_transforms
+ dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+
+ if dataset_config.unconditional_path is not None:
+ # we are using unconditional images
+ img_path = kwargs.get('path', None)
+ img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+ file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+ for ext in img_ext_list:
+ if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)):
+ self.unconditional_path = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)
+ self.has_unconditional = True
+ break
+
+ def load_unconditional_image(self: 'FileItemDTO'):
+ try:
+ img = Image.open(self.unconditional_path)
+ img = exif_transpose(img)
+ except Exception as e:
+ print(f"Error: {e}")
+ print(f"Error loading image: {self.unconditional_path}")
+
+ img = img.convert('RGB')
+ w, h = img.size
+ if w > h and self.scale_to_width < self.scale_to_height:
+ # throw error, they should match
+ raise ValueError(
+ f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+ elif h > w and self.scale_to_height < self.scale_to_width:
+ # throw error, they should match
+ raise ValueError(
+ f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+
+ if self.flip_x:
+ # do a flip
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
+ if self.flip_y:
+ # do a flip
+ img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+ if self.dataset_config.buckets:
+ # scale and crop based on file item
+ img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+ # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+ # crop
+ img = img.crop((
+ self.crop_x,
+ self.crop_y,
+ self.crop_x + self.crop_width,
+ self.crop_y + self.crop_height
+ ))
+ else:
+ raise Exception("Unconditional images are not supported for non-bucket datasets")
+
+ if self.aug_replay_spatial_transforms:
+ self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms)
+ else:
+ self.unconditional_tensor = self.unconditional_transforms(img)
+
+ def cleanup_unconditional(self: 'FileItemDTO'):
+ self.unconditional_tensor = None
+ self.unconditional_latent = None
+
+
+class PoiFileItemDTOMixin:
+ # Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
+ # items in the poi will always be inside the image when random cropping
+ def __init__(self: 'FileItemDTO', *args, **kwargs):
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ # poi is a name of the box point of interest in the caption json file
+ dataset_config = kwargs.get('dataset_config', None)
+ path = kwargs.get('path', None)
+ self.poi: Union[str, None] = dataset_config.poi
+ self.has_point_of_interest = self.poi is not None
+ self.poi_x: Union[int, None] = None
+ self.poi_y: Union[int, None] = None
+ self.poi_width: Union[int, None] = None
+ self.poi_height: Union[int, None] = None
+
+ if self.poi is not None:
+ # make sure latent caching is off
+ if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
+ raise Exception(
+ f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
+ )
+ # make sure we are loading through json
+ if dataset_config.caption_ext != 'json':
+ raise Exception(
+ f"Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
+ )
+ self.poi = self.poi.strip()
+ # get the caption path
+ file_path_no_ext = os.path.splitext(path)[0]
+ caption_path = file_path_no_ext + '.json'
+ if not os.path.exists(caption_path):
+ raise Exception(f"Error: caption file not found for poi: {caption_path}")
+ with open(caption_path, 'r', encoding='utf-8') as f:
+ json_data = json.load(f)
+ if 'poi' not in json_data or self.poi not in json_data['poi']:
+ print(f"Warning: poi '{self.poi}' not found in caption file: {caption_path}")
+ # poi has x, y, width, height
+ # default to the full image if no poi is found
+ self.poi_x = 0
+ self.poi_y = 0
+ self.poi_width = self.width
+ self.poi_height = self.height
+ try:
+ if self.poi in json_data['poi']:
+ poi = json_data['poi'][self.poi]
+ self.poi_x = int(poi['x'])
+ self.poi_y = int(poi['y'])
+ self.poi_width = int(poi['width'])
+ self.poi_height = int(poi['height'])
+ except Exception as e:
+ pass
+
+ # handle flipping
+ if kwargs.get('flip_x', False):
+ # flip the poi
+ self.poi_x = self.width - self.poi_x - self.poi_width
+ if kwargs.get('flip_y', False):
+ # flip the poi
+ self.poi_y = self.height - self.poi_y - self.poi_height
+
+ def setup_poi_bucket(self: 'FileItemDTO'):
+ initial_width = int(self.width * self.dataset_config.scale)
+ initial_height = int(self.height * self.dataset_config.scale)
+ # we are using poi, so we need to calculate the bucket based on the poi
+
+ # if img resolution is at or below the dataset resolution, just return and let the normal bucketing happen
+ img_resolution = get_resolution(initial_width, initial_height)
+ if img_resolution <= self.dataset_config.resolution:
+ return False # will trigger normal bucketing
+
+ bucket_tolerance = self.dataset_config.bucket_tolerance
+ poi_x = int(self.poi_x * self.dataset_config.scale)
+ poi_y = int(self.poi_y * self.dataset_config.scale)
+ poi_width = int(self.poi_width * self.dataset_config.scale)
+ poi_height = int(self.poi_height * self.dataset_config.scale)
+
+ # loop to keep expanding until we are at the proper resolution. This is not ideal, we can probably handle it better
+ num_loops = 0
+ while True:
+ # crop left
+ if poi_x > 0:
+ poi_x = random.randint(0, poi_x)
+ else:
+ poi_x = 0
+
+ # crop right
+ cr_min = poi_x + poi_width
+ if cr_min < initial_width:
+ crop_right = random.randint(poi_x + poi_width, initial_width)
+ else:
+ crop_right = initial_width
+
+ poi_width = crop_right - poi_x
+
+ if poi_y > 0:
+ poi_y = random.randint(0, poi_y)
+ else:
+ poi_y = 0
+
+ if poi_y + poi_height < initial_height:
+ crop_bottom = random.randint(poi_y + poi_height, initial_height)
+ else:
+ crop_bottom = initial_height
+
+ poi_height = crop_bottom - poi_y
+ try:
+ # now we have our random crop, but it may be smaller than resolution. Check and expand if needed
+ current_resolution = get_resolution(poi_width, poi_height)
+ except Exception as e:
+ print(f"Error: {e}")
+ print(f"Error getting resolution: {self.path}")
+ raise e
+ return False
+ if current_resolution >= self.dataset_config.resolution:
+ # We can break now
+ break
+ else:
+ num_loops += 1
+ if num_loops > 100:
+ print(
+ f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
+ return False
+
+ new_width = poi_width
+ new_height = poi_height
+
+ bucket_resolution = get_bucket_for_image_size(
+ new_width, new_height,
+ resolution=self.dataset_config.resolution,
+ divisibility=bucket_tolerance
+ )
+
+ width_scale_factor = bucket_resolution["width"] / new_width
+ height_scale_factor = bucket_resolution["height"] / new_height
+ # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
+ max_scale_factor = max(width_scale_factor, height_scale_factor)
+
+ self.scale_to_width = math.ceil(initial_width * max_scale_factor)
+ self.scale_to_height = math.ceil(initial_height * max_scale_factor)
+ self.crop_width = bucket_resolution['width']
+ self.crop_height = bucket_resolution['height']
+ self.crop_x = int(poi_x * max_scale_factor)
+ self.crop_y = int(poi_y * max_scale_factor)
+
+ if self.scale_to_width < self.crop_x + self.crop_width or self.scale_to_height < self.crop_y + self.crop_height:
+ # todo look into this. This still happens sometimes
+ print(f'Warning: poi crop box exceeds scaled image size for {self.path}')
+
+ return True
+
+
+class ArgBreakMixin:
+ # just stops super() calls from reaching object
+ def __init__(self, *args, **kwargs):
+ pass
+
+
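+ # Per-file latent caching: latents are stored on disk keyed by a hash of the crop, scale, flip and latent-version settings, so any change to those settings produces a new cache entry.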
+class LatentCachingFileItemDTOMixin:
+ def __init__(self, *args, **kwargs):
+ # if we have super, call it
+ if hasattr(super(), '__init__'):
+ super().__init__(*args, **kwargs)
+ self._encoded_latent: Union[torch.Tensor, None] = None
+ self._latent_path: Union[str, None] = None
+ self.is_latent_cached = False
+ self.is_caching_to_disk = False
+ self.is_caching_to_memory = False
+ self.latent_load_device = 'cpu'
+ # sd1 or sdxl or others
+ self.latent_space_version = 'sd1'
+ # todo, increment this if we change the latent format to invalidate cache
+ self.latent_version = 1
+
+ def get_latent_info_dict(self: 'FileItemDTO'):
+ item = OrderedDict([
+ ("filename", os.path.basename(self.path)),
+ ("scale_to_width", self.scale_to_width),
+ ("scale_to_height", self.scale_to_height),
+ ("crop_x", self.crop_x),
+ ("crop_y", self.crop_y),
+ ("crop_width", self.crop_width),
+ ("crop_height", self.crop_height),
+ ("latent_space_version", self.latent_space_version),
+ ("latent_version", self.latent_version),
+ ])
+ # when adding new fields, add them after these so existing cached latents keep the same hash
+ if self.flip_x:
+ item["flip_x"] = True
+ if self.flip_y:
+ item["flip_y"] = True
+ return item
+
+ def get_latent_path(self: 'FileItemDTO', recalculate=False):
+ if self._latent_path is not None and not recalculate:
+ return self._latent_path
+ else:
+ # we store latents in a folder in same path as image called _latent_cache
+ img_dir = os.path.dirname(self.path)
+ latent_dir = os.path.join(img_dir, '_latent_cache')
+ hash_dict = self.get_latent_info_dict()
+ filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
+ # get base64 hash of md5 checksum of hash_dict
+ hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+ hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+ hash_str = hash_str.replace('=', '')
+ self._latent_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
+
+ return self._latent_path
+
+ def cleanup_latent(self):
+ if self._encoded_latent is not None:
+ if not self.is_caching_to_memory:
+ # we are caching on disk, don't save in memory
+ self._encoded_latent = None
+ else:
+ # move it back to cpu
+ self._encoded_latent = self._encoded_latent.to('cpu')
+
+ def get_latent(self, device=None):
+ if not self.is_latent_cached:
+ return None
+ if self._encoded_latent is None:
+ # load it from disk
+ state_dict = load_file(
+ self.get_latent_path(),
+ # device=device if device is not None else self.latent_load_device
+ device='cpu'
+ )
+ self._encoded_latent = state_dict['latent']
+ return self._encoded_latent
+
+
+class LatentCachingMixin:
+ def __init__(self: 'AiToolkitDataset', **kwargs):
+ # if we have super, call it
+ if hasattr(super(), '__init__'):
+ super().__init__(**kwargs)
+ self.latent_cache = {}
+
+ def cache_latents_all_latents(self: 'AiToolkitDataset'):
+ print(f"Caching latents for {self.dataset_path}")
+ # cache all latents to disk
+ to_disk = self.is_caching_latents_to_disk
+ to_memory = self.is_caching_latents_to_memory
+
+ if to_disk:
+ print(" - Saving latents to disk")
+ if to_memory:
+ print(" - Keeping latents in memory")
+ # move sd items to cpu except for vae
+ self.sd.set_device_state_preset('cache_latents')
+
+ # use tqdm to show progress
+ i = 0
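+ # for each file: reuse the on-disk latent if present, otherwise encode with the VAE and save/keep it as configured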
+ for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
+ # set latent space version
+ if self.sd.model_config.latent_space_version is not None:
+ file_item.latent_space_version = self.sd.model_config.latent_space_version
+ elif self.sd.is_xl:
+ file_item.latent_space_version = 'sdxl'
+ elif self.sd.is_v3:
+ file_item.latent_space_version = 'sd3'
+ elif self.sd.is_auraflow:
+ file_item.latent_space_version = 'sdxl'
+ elif self.sd.is_flux:
+ file_item.latent_space_version = 'flux1'
+ elif self.sd.model_config.is_pixart_sigma:
+ file_item.latent_space_version = 'sdxl'
+ else:
+ file_item.latent_space_version = 'sd1'
+ file_item.is_caching_to_disk = to_disk
+ file_item.is_caching_to_memory = to_memory
+ file_item.latent_load_device = self.sd.device
+
+ latent_path = file_item.get_latent_path(recalculate=True)
+ # check if it is saved to disk already
+ if os.path.exists(latent_path):
+ if to_memory:
+ # load it into memory
+ state_dict = load_file(latent_path, device='cpu')
+ file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
+ else:
+ # not saved to disk, calculate
+ # load the image first
+ file_item.load_and_process_image(self.transform, only_load_latents=True)
+ dtype = self.sd.torch_dtype
+ device = self.sd.device_torch
+ # add batch dimension
+ try:
+ imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
+ latent = self.sd.encode_images(imgs).squeeze(0)
+ except Exception as e:
+ print(f"Error processing image: {file_item.path}")
+ print(f"Error: {str(e)}")
+ raise e
+ # save_latent
+ if to_disk:
+ state_dict = OrderedDict([
+ ('latent', latent.clone().detach().cpu()),
+ ])
+ # metadata
+ meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
+ os.makedirs(os.path.dirname(latent_path), exist_ok=True)
+ save_file(state_dict, latent_path, metadata=meta)
+
+ if to_memory:
+ # keep it in memory
+ file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
+
+ del imgs
+ del latent
+ del file_item.tensor
+
+ # flush(garbage_collect=False)
+ file_item.is_latent_cached = True
+ i += 1
+ # flush every 100
+ # if i % 100 == 0:
+ # flush()
+
+ # restore device state
+ self.sd.restore_device_state()
+
+
+class CLIPCachingMixin:
+ def __init__(self: 'AiToolkitDataset', **kwargs):
+ # if we have super, call it
+ if hasattr(super(), '__init__'):
+ super().__init__(**kwargs)
+ self.clip_vision_num_unconditional_cache = 20
+ self.clip_vision_unconditional_cache = []
+
+ def cache_clip_vision_to_disk(self: 'AiToolkitDataset'):
+ if not self.is_caching_clip_vision_to_disk:
+ return
+ with torch.no_grad():
+ print(f"Caching clip vision for {self.dataset_path}")
+
+ print(" - Saving clip to disk")
+ # move sd items to cpu except for vae
+ self.sd.set_device_state_preset('cache_clip')
+
+ # make sure the adapter has attributes
+ if self.sd.adapter is None:
+ raise Exception("Error: must have an adapter to cache clip vision to disk")
+
+ clip_image_processor: CLIPImageProcessor = None
+ if hasattr(self.sd.adapter, 'clip_image_processor'):
+ clip_image_processor = self.sd.adapter.clip_image_processor
+
+ if clip_image_processor is None:
+ raise Exception("Error: must have a clip image processor to cache clip vision to disk")
+
+ vision_encoder: CLIPVisionModelWithProjection = None
+ if hasattr(self.sd.adapter, 'image_encoder'):
+ vision_encoder = self.sd.adapter.image_encoder
+ if hasattr(self.sd.adapter, 'vision_encoder'):
+ vision_encoder = self.sd.adapter.vision_encoder
+
+ if vision_encoder is None:
+ raise Exception("Error: must have a vision encoder to cache clip vision to disk")
+
+ # move vision encoder to device
+ vision_encoder.to(self.sd.device)
+
+ is_quad = self.sd.adapter.config.quad_image
+ image_encoder_path = self.sd.adapter.config.image_encoder_path
+
+ dtype = self.sd.torch_dtype
+ device = self.sd.device_torch
+ if hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero:
+ # the unconditional input is random noise, so keep the configured number of samples
+ pass
+ else:
+ # a zero image never changes, so a single cached unconditional is enough
+ self.clip_vision_num_unconditional_cache = 1
+
+ # cache unconditionals
+ print(f" - Caching {self.clip_vision_num_unconditional_cache} unconditional clip vision to disk")
+ clip_vision_cache_path = os.path.join(self.dataset_config.clip_image_path, '_clip_vision_cache')
+
+ unconditional_paths = []
+
+ is_noise_zero = hasattr(self.sd.adapter, 'clip_noise_zero') and self.sd.adapter.clip_noise_zero
+
+ for i in range(self.clip_vision_num_unconditional_cache):
+ hash_dict = OrderedDict([
+ ("image_encoder_path", image_encoder_path),
+ ("is_quad", is_quad),
+ ("is_noise_zero", is_noise_zero),
+ ])
+ # get base64 hash of md5 checksum of hash_dict
+ hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
+ hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
+ hash_str = hash_str.replace('=', '')
+
+ uncond_path = os.path.join(clip_vision_cache_path, f'uncond_{hash_str}_{i}.safetensors')
+ if os.path.exists(uncond_path):
+ # skip it
+ unconditional_paths.append(uncond_path)
+ continue
+
+ # generate a random image
+ img_shape = (1, 3, self.sd.adapter.input_size, self.sd.adapter.input_size)
+ if is_noise_zero:
+ tensors_0_1 = torch.rand(img_shape).to(device, dtype=torch.float32)
+ else:
+ tensors_0_1 = torch.zeros(img_shape).to(device, dtype=torch.float32)
+ clip_image = clip_image_processor(
+ images=tensors_0_1,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ ).pixel_values
+
+ if is_quad:
+ # split the 2x2 grid into four tiles and stack them on the batch dimension
+ ci1, ci2 = clip_image.chunk(2, dim=2)
+ ci1, ci3 = ci1.chunk(2, dim=3)
+ ci2, ci4 = ci2.chunk(2, dim=3)
+ clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
+
+ clip_output = vision_encoder(
+ clip_image.to(device, dtype=dtype),
+ output_hidden_states=True
+ )
+ # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+ state_dict = OrderedDict([
+ ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
+ ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
+ ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
+ ])
+
+ os.makedirs(os.path.dirname(uncond_path), exist_ok=True)
+ save_file(state_dict, uncond_path)
+ unconditional_paths.append(uncond_path)
+
+ self.clip_vision_unconditional_cache = unconditional_paths
+
+ # use tqdm to show progress
+ i = 0
+ for file_item in tqdm(self.file_list, desc=f'Caching clip vision to disk'):
+ file_item.is_caching_clip_vision_to_disk = True
+ file_item.clip_vision_load_device = self.sd.device
+ file_item.clip_vision_is_quad = is_quad
+ file_item.clip_image_encoder_path = image_encoder_path
+ file_item.clip_vision_unconditional_paths = unconditional_paths
+ if file_item.has_clip_augmentations:
+ raise Exception("Error: clip vision caching is not supported with clip augmentations")
+
+ embedding_path = file_item.get_clip_vision_embeddings_path(recalculate=True)
+ # check if it is saved to disk already
+ if not os.path.exists(embedding_path):
+ # load the image first
+ file_item.load_clip_image()
+ # add batch dimension
+ clip_image = file_item.clip_image_tensor.unsqueeze(0).to(device, dtype=dtype)
+
+ if is_quad:
+ # split the 2x2 grid into four tiles and stack them on the batch dimension
+ ci1, ci2 = clip_image.chunk(2, dim=2)
+ ci1, ci3 = ci1.chunk(2, dim=3)
+ ci2, ci4 = ci2.chunk(2, dim=3)
+ clip_image = torch.cat([ci1, ci2, ci3, ci4], dim=0).detach()
+
+ clip_output = vision_encoder(
+ clip_image.to(device, dtype=dtype),
+ output_hidden_states=True
+ )
+
+ # make state_dict ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+ state_dict = OrderedDict([
+ ('image_embeds', clip_output.image_embeds.clone().detach().cpu()),
+ ('last_hidden_state', clip_output.hidden_states[-1].clone().detach().cpu()),
+ ('penultimate_hidden_states', clip_output.hidden_states[-2].clone().detach().cpu()),
+ ])
+ # metadata
+ meta = get_meta_for_safetensors(file_item.get_clip_vision_info_dict())
+ os.makedirs(os.path.dirname(embedding_path), exist_ok=True)
+ save_file(state_dict, embedding_path, metadata=meta)
+
+ del clip_image
+ del clip_output
+ del file_item.clip_image_tensor
+
+ # flush(garbage_collect=False)
+ file_item.is_vision_clip_cached = True
+ i += 1
+ # flush every 100
+ # if i % 100 == 0:
+ # flush()
+
+ # restore device state
+ self.sd.restore_device_state()
diff --git a/toolkit/dequantize.py b/toolkit/dequantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..54c8ec7b29862efa11b7fc3c9dc1efc8c1d66423
--- /dev/null
+++ b/toolkit/dequantize.py
@@ -0,0 +1,88 @@
+
+
+from functools import partial
+from optimum.quanto.tensor import QTensor
+import torch
+
+
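+# Rebuild a full-precision state dict from optimum-quanto's buffers: each `<name>._data` tensor is
+# multiplied by its `<name>._scale`, and the scale/auxiliary keys are dropped from the output.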
+def hacked_state_dict(self, *args, **kwargs):
+ orig_state_dict = self.orig_state_dict(*args, **kwargs)
+ new_state_dict = {}
+ for key, value in orig_state_dict.items():
+ if key.endswith("._scale"):
+ continue
+ if key.endswith(".input_scale"):
+ continue
+ if key.endswith(".output_scale"):
+ continue
+ if key.endswith("._data"):
+ key = key[:-6]
+ scale = orig_state_dict[key + "._scale"]
+ # scale is the original dtype
+ dtype = scale.dtype
+ scale = scale.float()
+ value = value.float()
+ dequantized = value * scale
+
+ # handle input and output scaling if they exist
+ input_scale = orig_state_dict.get(key + ".input_scale")
+
+ if input_scale is not None:
+ # make sure the tensor is 1.0
+ if input_scale.item() != 1.0:
+ raise ValueError("Input scale is not 1.0, cannot dequantize")
+
+ output_scale = orig_state_dict.get(key + ".output_scale")
+
+ if output_scale is not None:
+ # make sure the tensor is 1.0
+ if output_scale.item() != 1.0:
+ raise ValueError("Output scale is not 1.0, cannot dequantize")
+
+ new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
+ else:
+ new_state_dict[key] = value
+ return new_state_dict
+
+# hacks the state dict so we can dequantize before saving
+def patch_dequantization_on_save(model):
+ model.orig_state_dict = model.state_dict
+ model.state_dict = partial(hacked_state_dict, model)
+
+
+def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
+ """
+ Convert a quantized parameter back to a regular Parameter with floating point values.
+
+ Args:
+ module: The module containing the parameter to unquantize
+ param_name: Name of the parameter to unquantize (e.g., 'weight', 'bias')
+
+ Returns:
+ bool: True if parameter was unquantized, False if it was already unquantized
+ """
+
+ # Check if the parameter exists
+ if not hasattr(module, param_name):
+ raise AttributeError(f"Module has no parameter named '{param_name}'")
+
+ param = getattr(module, param_name)
+
+ # If it's not a parameter or not quantized, nothing to do
+ if not isinstance(param, torch.nn.Parameter):
+ raise TypeError(f"'{param_name}' is not a Parameter")
+ if not isinstance(param, QTensor):
+ return False
+
+ # Convert to float tensor while preserving device and requires_grad
+ with torch.no_grad():
+ float_tensor = param.float()
+ new_param = torch.nn.Parameter(
+ float_tensor,
+ requires_grad=param.requires_grad
+ )
+
+ # Replace the parameter
+ setattr(module, param_name, new_param)
+
+ return True
\ No newline at end of file
diff --git a/toolkit/ema.py b/toolkit/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b3a7ea0974e37d783cb75e93d50659105bbb49
--- /dev/null
+++ b/toolkit/ema.py
@@ -0,0 +1,346 @@
+from __future__ import division
+from __future__ import unicode_literals
+
+from typing import Iterable, Optional
+import weakref
+import copy
+import contextlib
+from toolkit.optimizers.optimizer_utils import copy_stochastic
+
+import torch
+
+
+# Partially based on:
+# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+class ExponentialMovingAverage:
+ """
+ Maintains (exponential) moving average of a set of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter` (typically from
+ `model.parameters()`).
+ Note that EMA is computed on *all* provided parameters,
+ regardless of whether or not they have `requires_grad = True`;
+ this allows a single EMA object to be consistently used even
+ if which parameters are trainable changes step to step.
+
+ If you want to exclude some parameters from the EMA, do not pass them
+ to the object in the first place. For example:
+
+ ExponentialMovingAverage(
+ parameters=[p for p in model.parameters() if p.requires_grad],
+ decay=0.9
+ )
+
+ will ignore parameters that do not require grad.
+
+ decay: The exponential decay.
+
+ use_num_updates: Whether to use number of updates when computing
+ averages.
+ """
+
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter] = None,
+ decay: float = 0.995,
+ use_num_updates: bool = False,
+ # feeds the decayed EMA correction back into the parameter
+ use_feedback: bool = False,
+ param_multiplier: float = 1.0
+ ):
+ if parameters is None:
+ raise ValueError("parameters must be provided")
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+ self.decay = decay
+ self.num_updates = 0 if use_num_updates else None
+ self.use_feedback = use_feedback
+ self.param_multiplier = param_multiplier
+ parameters = list(parameters)
+ self.shadow_params = [
+ p.clone().detach()
+ for p in parameters
+ ]
+ self.collected_params = None
+ self._is_train_mode = True
+ # By maintaining only a weakref to each parameter,
+ # we maintain the old GC behaviour of ExponentialMovingAverage:
+ # if the model goes out of scope but the ExponentialMovingAverage
+ # is kept, no references to the model or its parameters will be
+ # maintained, and the model will be cleaned up.
+ self._params_refs = [weakref.ref(p) for p in parameters]
+
+ def _get_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]]
+ ) -> Iterable[torch.nn.Parameter]:
+ if parameters is None:
+ parameters = [p() for p in self._params_refs]
+ if any(p is None for p in parameters):
+ raise ValueError(
+ "(One of) the parameters with which this "
+ "ExponentialMovingAverage "
+ "was initialized no longer exists (was garbage collected);"
+ " please either provide `parameters` explicitly or keep "
+ "the model to which they belong from being garbage "
+ "collected."
+ )
+ return parameters
+ else:
+ parameters = list(parameters)
+ if len(parameters) != len(self.shadow_params):
+ raise ValueError(
+ "Number of parameters passed as argument is different "
+ "from number of shadow parameters maintained by this "
+ "ExponentialMovingAverage"
+ )
+ return parameters
+
+ def update(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
+ """
+ Update currently maintained parameters.
+
+ Call this every time the parameters are updated, such as the result of
+ the `optimizer.step()` call.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ decay = self.decay
+ if self.num_updates is not None:
+ self.num_updates += 1
+ decay = min(
+ decay,
+ (1 + self.num_updates) / (10 + self.num_updates)
+ )
+ one_minus_decay = 1.0 - decay
+ with torch.no_grad():
+ for s_param, param in zip(self.shadow_params, parameters):
+ s_param_float = s_param.float()
+ if s_param.dtype != torch.float32:
+ s_param_float = s_param_float.to(torch.float32)
+ param_float = param
+ if param.dtype != torch.float32:
+ param_float = param_float.to(torch.float32)
+ tmp = (s_param_float - param_float)
+ # tmp will be a new tensor so we can do in-place
+ tmp.mul_(one_minus_decay)
+ s_param_float.sub_(tmp)
+
+ update_param = False
+ if self.use_feedback:
+ param_float.add_(tmp)
+ update_param = True
+
+ if self.param_multiplier != 1.0:
+ param_float.mul_(self.param_multiplier)
+ update_param = True
+
+ if s_param.dtype != torch.float32:
+ copy_stochastic(s_param, s_param_float)
+
+ if update_param and param.dtype != torch.float32:
+ copy_stochastic(param, param_float)
+
+
+ def copy_to(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
+ """
+ Copy current averaged parameters into given collection of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.data)
+
+ def store(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
+ """
+ Save the current parameters for restoring later.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.collected_params = [
+ param.clone()
+ for param in parameters
+ ]
+
+ def restore(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ if self.collected_params is None:
+ raise RuntimeError(
+ "This ExponentialMovingAverage has no `store()`ed weights "
+ "to `restore()`"
+ )
+ parameters = self._get_parameters(parameters)
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
+
+ @contextlib.contextmanager
+ def average_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ):
+ r"""
+ Context manager for validation/inference with averaged parameters.
+
+ Equivalent to:
+
+ ema.store()
+ ema.copy_to()
+ try:
+ ...
+ finally:
+ ema.restore()
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.store(parameters)
+ self.copy_to(parameters)
+ try:
+ yield
+ finally:
+ self.restore(parameters)
+
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+            device: like the `device` argument to `torch.Tensor.to`
+            dtype: like the `dtype` argument to `torch.Tensor.to`; only
+                applied to floating-point buffers
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.shadow_params
+ ]
+ if self.collected_params is not None:
+ self.collected_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.collected_params
+ ]
+ return
+
+ def state_dict(self) -> dict:
+ r"""Returns the state of the ExponentialMovingAverage as a dict."""
+ # Following PyTorch conventions, references to tensors are returned:
+ # "returns a reference to the state and not its copy!" -
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+ return {
+ "decay": self.decay,
+ "num_updates": self.num_updates,
+ "shadow_params": self.shadow_params,
+ "collected_params": self.collected_params
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the ExponentialMovingAverage state.
+
+ Args:
+ state_dict (dict): EMA state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = copy.deepcopy(state_dict)
+ self.decay = state_dict["decay"]
+ if self.decay < 0.0 or self.decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+ self.num_updates = state_dict["num_updates"]
+ assert self.num_updates is None or isinstance(self.num_updates, int), \
+ "Invalid num_updates"
+
+ self.shadow_params = state_dict["shadow_params"]
+ assert isinstance(self.shadow_params, list), \
+ "shadow_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.shadow_params
+ ), "shadow_params must all be Tensors"
+
+ self.collected_params = state_dict["collected_params"]
+ if self.collected_params is not None:
+ assert isinstance(self.collected_params, list), \
+ "collected_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.collected_params
+ ), "collected_params must all be Tensors"
+ assert len(self.collected_params) == len(self.shadow_params), \
+ "collected_params and shadow_params had different lengths"
+
+ if len(self.shadow_params) == len(self._params_refs):
+            # Consistent with torch.optim.Optimizer, cast things to a consistent
+            # device and dtype with the parameters
+ params = [p() for p in self._params_refs]
+ # If parameters have been garbage collected, just load the state
+ # we were given without change.
+ if not any(p is None for p in params):
+ # ^ parameter references are still good
+ for i, p in enumerate(params):
+ self.shadow_params[i] = self.shadow_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ if self.collected_params is not None:
+ self.collected_params[i] = self.collected_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ else:
+ raise ValueError(
+ "Tried to `load_state_dict()` with the wrong number of "
+ "parameters in the saved state."
+ )
+
+ def eval(self):
+ if self._is_train_mode:
+ with torch.no_grad():
+ self.store()
+ self.copy_to()
+ self._is_train_mode = False
+
+ def train(self):
+ if not self._is_train_mode:
+ with torch.no_grad():
+ self.restore()
+ self._is_train_mode = True
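+
+
+# Minimal usage sketch (illustrative only, guarded so importing this module is
+# unaffected). The tiny linear model, optimizer and step count below are
+# hypothetical stand-ins, not part of the toolkit; the point is the call order:
+# `update()` after each `optimizer.step()`, and `average_parameters()`
+# (or `eval()`/`train()`) around evaluation.
+if __name__ == "__main__":
+    _model = torch.nn.Linear(4, 4)
+    _opt = torch.optim.SGD(_model.parameters(), lr=1e-2)
+    _ema = ExponentialMovingAverage(_model.parameters(), decay=0.995)
+    for _ in range(3):
+        _opt.zero_grad()
+        _loss = _model(torch.randn(2, 4)).pow(2).mean()
+        _loss.backward()
+        _opt.step()
+        _ema.update()  # fold the freshly optimized weights into the average
+    with _ema.average_parameters():
+        # the model temporarily holds the EMA weights here; they are restored on exit
+        _ = _model(torch.randn(2, 4))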
diff --git a/toolkit/embedding.py b/toolkit/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..94ba3f2f33bfa023f31da37f12c3ca4a34f0cc21
--- /dev/null
+++ b/toolkit/embedding.py
@@ -0,0 +1,284 @@
+import json
+import os
+from collections import OrderedDict
+
+import safetensors
+import torch
+from typing import TYPE_CHECKING
+
+from safetensors.torch import save_file
+
+from toolkit.metadata import get_meta_for_safetensors
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+ from toolkit.config_modules import EmbeddingConfig
+
+
+# this is a frankenstein mix of automatic1111 and my own code
+
+class Embedding:
+ def __init__(
+ self,
+ sd: 'StableDiffusion',
+ embed_config: 'EmbeddingConfig',
+ state_dict: OrderedDict = None,
+ ):
+ self.name = embed_config.trigger
+ self.sd = sd
+ self.trigger = embed_config.trigger
+ self.embed_config = embed_config
+ self.step = 0
+ # setup our embedding
+ # Add the placeholder token in tokenizer
+ placeholder_tokens = [self.embed_config.trigger]
+
+ # add dummy tokens for multi-vector
+ additional_tokens = []
+ for i in range(1, self.embed_config.tokens):
+ additional_tokens.append(f"{self.embed_config.trigger}_{i}")
+ placeholder_tokens += additional_tokens
+
+ # handle dual tokenizer
+ self.tokenizer_list = self.sd.tokenizer if isinstance(self.sd.tokenizer, list) else [self.sd.tokenizer]
+ self.text_encoder_list = self.sd.text_encoder if isinstance(self.sd.text_encoder, list) else [
+ self.sd.text_encoder]
+
+ self.placeholder_token_ids = []
+ self.embedding_tokens = []
+
+ print(f"Adding {placeholder_tokens} tokens to tokenizer")
+ print(f"Adding {self.embed_config.tokens} tokens to tokenizer")
+
+ for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
+ num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+ if num_added_tokens != self.embed_config.tokens:
+ raise ValueError(
+ f"The tokenizer already contains the token {self.embed_config.trigger}. Please pass a different"
+ f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ init_token_ids = tokenizer.encode(self.embed_config.init_words, add_special_tokens=False)
+            # if there are more init token ids than embedding tokens, truncate; if fewer, pad with "*"
+ if len(init_token_ids) > self.embed_config.tokens:
+ init_token_ids = init_token_ids[:self.embed_config.tokens]
+ elif len(init_token_ids) < self.embed_config.tokens:
+ pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
+ init_token_ids += pad_token_id * (self.embed_config.tokens - len(init_token_ids))
+
+ placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
+ self.placeholder_token_ids.append(placeholder_token_ids)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = text_encoder.get_input_embeddings().weight.data
+ with torch.no_grad():
+ for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
+ token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+
+ # replace "[name] with this. on training. This is automatically generated in pipeline on inference
+ self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
+
+ # backup text encoder embeddings
+ self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
+
+ def restore_embeddings(self):
+ with torch.no_grad():
+ # Let's make sure we don't update any embedding weights besides the newly added token
+ for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
+ self.tokenizer_list,
+ self.orig_embeds_params,
+ self.placeholder_token_ids):
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+                index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
+ text_encoder.get_input_embeddings().weight[
+ index_no_updates
+ ] = orig_embeds[index_no_updates]
+ weight = text_encoder.get_input_embeddings().weight
+ pass
+
+ def get_trainable_params(self):
+ params = []
+ for text_encoder in self.text_encoder_list:
+ params += text_encoder.get_input_embeddings().parameters()
+ return params
+
+ def _get_vec(self, text_encoder_idx=0):
+ # should we get params instead
+ # create vector from token embeds
+ token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
+ # stack the tokens along batch axis adding that axis
+ new_vector = torch.stack(
+ [token_embeds[token_id] for token_id in self.placeholder_token_ids[text_encoder_idx]],
+ dim=0
+ )
+ return new_vector
+
+ def _set_vec(self, new_vector, text_encoder_idx=0):
+ # shape is (1, 768) for SD 1.5 for 1 token
+ token_embeds = self.text_encoder_list[text_encoder_idx].get_input_embeddings().weight.data
+ for i in range(new_vector.shape[0]):
+ # apply the weights to the placeholder tokens while preserving gradient
+ token_embeds[self.placeholder_token_ids[text_encoder_idx][i]] = new_vector[i].clone()
+
+ # make setter and getter for vec
+ @property
+ def vec(self):
+ return self._get_vec(0)
+
+ @vec.setter
+ def vec(self, new_vector):
+ self._set_vec(new_vector, 0)
+
+ @property
+ def vec2(self):
+ return self._get_vec(1)
+
+ @vec2.setter
+ def vec2(self, new_vector):
+ self._set_vec(new_vector, 1)
+
+    # diffusers automatically expands the token, so "test123" becomes "test123 test123_1 test123_2" etc.
+    # however, during training we don't use that pipeline, so we have to do it ourselves
+ def inject_embedding_to_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
+ output_prompt = prompt
+        embedding_tokens = self.embedding_tokens[0]  # should be the same
+ default_replacements = ["[name]", "[trigger]"]
+
+ replace_with = embedding_tokens if expand_token else self.trigger
+ if to_replace_list is None:
+ to_replace_list = default_replacements
+ else:
+ to_replace_list += default_replacements
+
+ # remove duplicates
+ to_replace_list = list(set(to_replace_list))
+
+ # replace them all
+ for to_replace in to_replace_list:
+ # replace it
+ output_prompt = output_prompt.replace(to_replace, replace_with)
+
+ # see how many times replace_with is in the prompt
+ num_instances = output_prompt.count(replace_with)
+
+ if num_instances == 0 and add_if_not_present:
+ # add it to the beginning of the prompt
+ output_prompt = replace_with + " " + output_prompt
+
+ if num_instances > 1:
+ print(
+ f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+ return output_prompt
+
+ def state_dict(self):
+ if self.sd.is_xl:
+ state_dict = OrderedDict()
+ state_dict['clip_l'] = self.vec
+ state_dict['clip_g'] = self.vec2
+ else:
+ state_dict = OrderedDict()
+ state_dict['emb_params'] = self.vec
+
+ return state_dict
+
+ def save(self, filename):
+ # todo check to see how to get the vector out of the embedding
+
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ # todo get these
+ "sd_checkpoint": None,
+ "sd_checkpoint_name": None,
+ "notes": None,
+ }
+        # TODO we do not currently support this. Check how auto is doing it. Only safetensors is supported for SDXL
+ if filename.endswith('.pt'):
+ torch.save(embedding_data, filename)
+ elif filename.endswith('.bin'):
+ torch.save(embedding_data, filename)
+ elif filename.endswith('.safetensors'):
+ # save the embedding as a safetensors file
+ state_dict = self.state_dict()
+ # add all embedding data (except string_to_param), to metadata
+ metadata = OrderedDict({k: json.dumps(v) for k, v in embedding_data.items() if k != "string_to_param"})
+ metadata["string_to_param"] = {"*": "emb_params"}
+ save_meta = get_meta_for_safetensors(metadata, name=self.name)
+ save_file(state_dict, filename, metadata=save_meta)
+
+ def load_embedding_from_file(self, file_path, device):
+ # full path
+ path = os.path.realpath(file_path)
+ filename = os.path.basename(path)
+ name, ext = os.path.splitext(filename)
+ tensors = {}
+ ext = ext.upper()
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
+ return
+
+ if ext in ['.BIN', '.PT']:
+ # todo check this
+ if self.sd.is_xl:
+ raise Exception("XL not supported yet for bin, pt")
+ data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ # rebuild the embedding from the safetensors file if it has it
+ with safetensors.torch.safe_open(path, framework="pt", device="cpu") as f:
+ metadata = f.metadata()
+ for k in f.keys():
+ tensors[k] = f.get_tensor(k)
+ # data = safetensors.torch.load_file(path, device="cpu")
+ if metadata and 'string_to_param' in metadata and 'emb_params' in tensors:
+ # our format
+ def try_json(v):
+ try:
+ return json.loads(v)
+ except:
+ return v
+
+ data = {k: try_json(v) for k, v in metadata.items()}
+ data['string_to_param'] = {'*': tensors['emb_params']}
+ else:
+ # old format
+ data = tensors
+ else:
+ return
+
+ if self.sd.is_xl:
+ self.vec = tensors['clip_l'].detach().to(device, dtype=torch.float32)
+ self.vec2 = tensors['clip_g'].detach().to(device, dtype=torch.float32)
+ if 'step' in data:
+ self.step = int(data['step'])
+ else:
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+                    # fix for torch 1.12.1 loading a file saved with torch 1.11
+                    param_dict = getattr(param_dict, '_parameters')
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(
+ f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ if 'step' in data:
+ self.step = int(data['step'])
+
+ self.vec = emb.detach().to(device, dtype=torch.float32)
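+
+
+# Illustration of inject_embedding_to_prompt (comment-only sketch; the trigger
+# word "p3rs0n" and a 3-token embedding are made-up examples):
+#   prompt:               "photo of [trigger] at the beach"
+#   expand_token=False -> "photo of p3rs0n at the beach"
+#   expand_token=True  -> "photo of p3rs0n p3rs0n_1 p3rs0n_2 at the beach"
+# If neither "[trigger]"/"[name]" nor the trigger word ends up in the prompt and
+# add_if_not_present=True, the trigger is prepended to the prompt instead.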
diff --git a/toolkit/esrgan_utils.py b/toolkit/esrgan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..25a8bfbada1bff84bc6bb1a49149d846c9c8c379
--- /dev/null
+++ b/toolkit/esrgan_utils.py
@@ -0,0 +1,51 @@
+
+to_basicsr_dict = {
+ 'model.0.weight': 'conv_first.weight',
+ 'model.0.bias': 'conv_first.bias',
+ 'model.1.sub.23.weight': 'conv_body.weight',
+ 'model.1.sub.23.bias': 'conv_body.bias',
+ 'model.3.weight': 'conv_up1.weight',
+ 'model.3.bias': 'conv_up1.bias',
+ 'model.6.weight': 'conv_up2.weight',
+ 'model.6.bias': 'conv_up2.bias',
+ 'model.8.weight': 'conv_hr.weight',
+ 'model.8.bias': 'conv_hr.bias',
+ 'model.10.bias': 'conv_last.bias',
+ 'model.10.weight': 'conv_last.weight',
+ # 'model.1.sub.0.RDB1.conv1.0.weight': 'body.0.rdb1.conv1.weight'
+}
+
+def convert_state_dict_to_basicsr(state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k in to_basicsr_dict:
+ new_state_dict[to_basicsr_dict[k]] = v
+ elif k.startswith('model.1.sub.'):
+ bsr_name = k.replace('model.1.sub.', 'body.').lower()
+ bsr_name = bsr_name.replace('.0.weight', '.weight')
+ bsr_name = bsr_name.replace('.0.bias', '.bias')
+ new_state_dict[bsr_name] = v
+ else:
+ new_state_dict[k] = v
+ return new_state_dict
+
+
+# just matching a commonly used format
+def convert_basicsr_state_dict_to_save_format(state_dict):
+ new_state_dict = {}
+ to_basicsr_dict_values = list(to_basicsr_dict.values())
+ for k, v in state_dict.items():
+ if k in to_basicsr_dict_values:
+ for key, value in to_basicsr_dict.items():
+ if value == k:
+ new_state_dict[key] = v
+
+ elif k.startswith('body.'):
+ bsr_name = k.replace('body.', 'model.1.sub.').lower()
+ bsr_name = bsr_name.replace('rdb', 'RDB')
+ bsr_name = bsr_name.replace('.weight', '.0.weight')
+ bsr_name = bsr_name.replace('.bias', '.0.bias')
+ new_state_dict[bsr_name] = v
+ else:
+ new_state_dict[k] = v
+ return new_state_dict
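+
+
+# Minimal round-trip sketch (illustrative; the two keys below are just examples
+# of the old-style ESRGAN naming, not a full state dict):
+if __name__ == "__main__":
+    old_sd = {
+        'model.0.weight': 0.0,
+        'model.1.sub.0.RDB1.conv1.0.weight': 0.0,
+    }
+    bsr_sd = convert_state_dict_to_basicsr(old_sd)
+    # -> {'conv_first.weight': 0.0, 'body.0.rdb1.conv1.weight': 0.0}
+    back = convert_basicsr_state_dict_to_save_format(bsr_sd)
+    assert set(back) == set(old_sd)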
diff --git a/toolkit/extension.py b/toolkit/extension.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d1f38e57c7295546bb621c6f3234346f92f73a1
--- /dev/null
+++ b/toolkit/extension.py
@@ -0,0 +1,57 @@
+import os
+import importlib
+import pkgutil
+from typing import List
+
+from toolkit.paths import TOOLKIT_ROOT
+
+
+class Extension(object):
+ """Base class for extensions.
+
+ Extensions are registered with the ExtensionManager, which is
+ responsible for calling the extension's load() and unload()
+ methods at the appropriate times.
+
+ """
+
+ name: str = None
+ uid: str = None
+
+ @classmethod
+ def get_process(cls):
+ # extend in subclass
+ pass
+
+
+def get_all_extensions() -> List[Extension]:
+ extension_folders = ['extensions', 'extensions_built_in']
+
+ # This will hold the classes from all extension modules
+ all_extension_classes: List[Extension] = []
+
+ # Iterate over all directories (i.e., packages) in the "extensions" directory
+ for sub_dir in extension_folders:
+ extensions_dir = os.path.join(TOOLKIT_ROOT, sub_dir)
+ for (_, name, _) in pkgutil.iter_modules([extensions_dir]):
+ try:
+ # Import the module
+ module = importlib.import_module(f"{sub_dir}.{name}")
+ # Get the value of the AI_TOOLKIT_EXTENSIONS variable
+ extensions = getattr(module, "AI_TOOLKIT_EXTENSIONS", None)
+ # Check if the value is a list
+ if isinstance(extensions, list):
+ # Iterate over the list and add the classes to the main list
+ all_extension_classes.extend(extensions)
+ except ImportError as e:
+ print(f"Failed to import the {name} module. Error: {str(e)}")
+
+ return all_extension_classes
+
+
+def get_all_extensions_process_dict():
+ all_extensions = get_all_extensions()
+ process_dict = {}
+ for extension in all_extensions:
+ process_dict[extension.uid] = extension.get_process()
+ return process_dict
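+
+
+# Sketch of how an extension module is expected to look (comment-only,
+# illustrative names). A file such as extensions/example/__init__.py would
+# define a subclass and export it through the module-level
+# AI_TOOLKIT_EXTENSIONS list that get_all_extensions() looks for:
+#
+#     from toolkit.extension import Extension
+#
+#     class ExampleExtension(Extension):
+#         uid = "example"
+#         name = "Example"
+#
+#         @classmethod
+#         def get_process(cls):
+#             from .process import ExampleProcess  # import lazily
+#             return ExampleProcess
+#
+#     AI_TOOLKIT_EXTENSIONS = [ExampleExtension]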
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcf282046c12eba068dac7e918945135b705ef9e
--- /dev/null
+++ b/toolkit/guidance.py
@@ -0,0 +1,693 @@
+import torch
+from typing import Literal, Optional
+
+from toolkit.basic import value_map
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
+from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.train_tools import get_torch_dtype
+
+GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]
+
+DIFFERENTIAL_SCALER = 0.2
+
+
+# DIFFERENTIAL_SCALER = 0.25
+
+
+def get_differential_mask(
+ conditional_latents: torch.Tensor,
+ unconditional_latents: torch.Tensor,
+ threshold: float = 0.2,
+ gradient: bool = False,
+):
+ # make a differential mask
+ differential_mask = torch.abs(conditional_latents - unconditional_latents)
+ max_differential = \
+ differential_mask.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
+ differential_scaler = 1.0 / max_differential
+ differential_mask = differential_mask * differential_scaler
+
+ if gradient:
+        # we need to scale it to 0-1
+ # differential_mask = differential_mask - differential_mask.min()
+ # differential_mask = differential_mask / differential_mask.max()
+ # add 0.2 threshold to both sides and clip
+ differential_mask = value_map(
+ differential_mask,
+ differential_mask.min(),
+ differential_mask.max(),
+ 0 - threshold,
+ 1 + threshold
+ )
+ differential_mask = torch.clamp(differential_mask, 0.0, 1.0)
+ else:
+
+ # make everything less than 0.2 be 0.0 and everything else be 1.0
+ differential_mask = torch.where(
+ differential_mask < threshold,
+ torch.zeros_like(differential_mask),
+ torch.ones_like(differential_mask)
+ )
+ return differential_mask
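+
+# Worked example (comment-only, made-up numbers): with threshold=0.2 and
+# gradient=False, per-pixel |cond - uncond| values of [0.05, 0.10, 0.50] are
+# first rescaled by 1 / 0.50 to [0.1, 0.2, 1.0] and then thresholded to
+# [0.0, 1.0, 1.0]. With gradient=True the rescaled values are instead remapped
+# by value_map onto [0 - threshold, 1 + threshold] and clamped to [0, 1],
+# giving a soft mask rather than a hard one.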
+
+
+def get_targeted_polarity_loss(
+ noisy_latents: torch.Tensor,
+ conditional_embeds: PromptEmbeds,
+ match_adapter_assist: bool,
+ network_weight_list: list,
+ timesteps: torch.Tensor,
+ pred_kwargs: dict,
+ batch: 'DataLoaderBatchDTO',
+ noise: torch.Tensor,
+ sd: 'StableDiffusion',
+ **kwargs
+):
+ dtype = get_torch_dtype(sd.torch_dtype)
+ device = sd.device_torch
+ with torch.no_grad():
+ conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+ unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+ # inputs_abs_mean = torch.abs(conditional_latents).mean(dim=[1, 2, 3], keepdim=True)
+ # noise_abs_mean = torch.abs(noise).mean(dim=[1, 2, 3], keepdim=True)
+ differential_scaler = DIFFERENTIAL_SCALER
+
+ unconditional_diff = (unconditional_latents - conditional_latents)
+ unconditional_diff_noise = unconditional_diff * differential_scaler
+ conditional_diff = (conditional_latents - unconditional_latents)
+ conditional_diff_noise = conditional_diff * differential_scaler
+ conditional_diff_noise = conditional_diff_noise.detach().requires_grad_(False)
+ unconditional_diff_noise = unconditional_diff_noise.detach().requires_grad_(False)
+ #
+ baseline_conditional_noisy_latents = sd.add_noise(
+ conditional_latents,
+ noise,
+ timesteps
+ ).detach()
+
+ baseline_unconditional_noisy_latents = sd.add_noise(
+ unconditional_latents,
+ noise,
+ timesteps
+ ).detach()
+
+ conditional_noise = noise + unconditional_diff_noise
+ unconditional_noise = noise + conditional_diff_noise
+
+ conditional_noisy_latents = sd.add_noise(
+ conditional_latents,
+ conditional_noise,
+ timesteps
+ ).detach()
+
+ unconditional_noisy_latents = sd.add_noise(
+ unconditional_latents,
+ unconditional_noise,
+ timesteps
+ ).detach()
+
+ # double up everything to run it through all at once
+ cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+ cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
+ cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
+ # cat_baseline_noisy_latents = torch.cat(
+ # [baseline_conditional_noisy_latents, baseline_unconditional_noisy_latents],
+ # dim=0
+ # )
+
+ # Disable the LoRA network so we can predict parent network knowledge without it
+ # sd.network.is_active = False
+ # sd.unet.eval()
+
+ # Predict noise to get a baseline of what the parent network wants to do with the latents + noise.
+ # This acts as our control to preserve the unaltered parts of the image.
+ # baseline_prediction = sd.predict_noise(
+ # latents=cat_baseline_noisy_latents.to(device, dtype=dtype).detach(),
+ # conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
+ # timestep=cat_timesteps,
+ # guidance_scale=1.0,
+ # **pred_kwargs # adapter residuals in here
+ # ).detach()
+
+ # conditional_baseline_prediction, unconditional_baseline_prediction = torch.chunk(baseline_prediction, 2, dim=0)
+
+ # negative_network_weights = [weight * -1.0 for weight in network_weight_list]
+ # positive_network_weights = [weight * 1.0 for weight in network_weight_list]
+ # cat_network_weight_list = positive_network_weights + negative_network_weights
+
+ # turn the LoRA network back on.
+ sd.unet.train()
+ # sd.network.is_active = True
+
+ # sd.network.multiplier = cat_network_weight_list
+
+ # do our prediction with LoRA active on the scaled guidance latents
+ prediction = sd.predict_noise(
+ latents=cat_latents.to(device, dtype=dtype).detach(),
+ conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
+ timestep=cat_timesteps,
+ guidance_scale=1.0,
+ **pred_kwargs # adapter residuals in here
+ )
+
+ # prediction = prediction - baseline_prediction
+
+ pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
+ # pred_pos = pred_pos - conditional_baseline_prediction
+ # pred_neg = pred_neg - unconditional_baseline_prediction
+
+ pred_loss = torch.nn.functional.mse_loss(
+ pred_pos.float(),
+ conditional_noise.float(),
+ reduction="none"
+ )
+ pred_loss = pred_loss.mean([1, 2, 3])
+
+ pred_neg_loss = torch.nn.functional.mse_loss(
+ pred_neg.float(),
+ unconditional_noise.float(),
+ reduction="none"
+ )
+ pred_neg_loss = pred_neg_loss.mean([1, 2, 3])
+
+ loss = pred_loss + pred_neg_loss
+
+ loss = loss.mean()
+ loss.backward()
+
+ # detach it so parent class can run backward on no grads without throwing error
+ loss = loss.detach()
+ loss.requires_grad_(True)
+
+ return loss
+
+def get_direct_guidance_loss(
+ noisy_latents: torch.Tensor,
+ conditional_embeds: 'PromptEmbeds',
+ match_adapter_assist: bool,
+ network_weight_list: list,
+ timesteps: torch.Tensor,
+ pred_kwargs: dict,
+ batch: 'DataLoaderBatchDTO',
+ noise: torch.Tensor,
+ sd: 'StableDiffusion',
+ unconditional_embeds: Optional[PromptEmbeds] = None,
+ mask_multiplier=None,
+ prior_pred=None,
+ **kwargs
+):
+ with torch.no_grad():
+        # Perform direct guidance
+ dtype = get_torch_dtype(sd.torch_dtype)
+ device = sd.device_torch
+
+
+ conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+ unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+ conditional_noisy_latents = sd.add_noise(
+ conditional_latents,
+ # target_noise,
+ noise,
+ timesteps
+ ).detach()
+
+ unconditional_noisy_latents = sd.add_noise(
+ unconditional_latents,
+ noise,
+ timesteps
+ ).detach()
+ # turn the LoRA network back on.
+ sd.unet.train()
+ # sd.network.is_active = True
+
+ # sd.network.multiplier = network_weight_list
+ # do our prediction with LoRA active on the scaled guidance latents
+ if unconditional_embeds is not None:
+ unconditional_embeds = unconditional_embeds.to(device, dtype=dtype).detach()
+ unconditional_embeds = concat_prompt_embeds([unconditional_embeds, unconditional_embeds])
+
+ prediction = sd.predict_noise(
+ latents=torch.cat([unconditional_noisy_latents, conditional_noisy_latents]).to(device, dtype=dtype).detach(),
+ conditional_embeddings=concat_prompt_embeds([conditional_embeds,conditional_embeds]).to(device, dtype=dtype).detach(),
+ unconditional_embeddings=unconditional_embeds,
+ timestep=torch.cat([timesteps, timesteps]),
+ guidance_scale=1.0,
+ **pred_kwargs # adapter residuals in here
+ )
+
+ noise_pred_uncond, noise_pred_cond = torch.chunk(prediction, 2, dim=0)
+
+ guidance_scale = 1.1
+ guidance_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_cond - noise_pred_uncond
+ )
+
+ guidance_loss = torch.nn.functional.mse_loss(
+ guidance_pred.float(),
+ noise.detach().float(),
+ reduction="none"
+ )
+ if mask_multiplier is not None:
+ guidance_loss = guidance_loss * mask_multiplier
+
+ guidance_loss = guidance_loss.mean([1, 2, 3])
+
+ guidance_loss = guidance_loss.mean()
+
+ # loss = guidance_loss + masked_noise_loss
+ loss = guidance_loss
+
+ loss.backward()
+
+ # detach it so parent class can run backward on no grads without throwing error
+ loss = loss.detach()
+ loss.requires_grad_(True)
+
+ return loss
+
+
+# targeted
+def get_targeted_guidance_loss(
+ noisy_latents: torch.Tensor,
+ conditional_embeds: 'PromptEmbeds',
+ match_adapter_assist: bool,
+ network_weight_list: list,
+ timesteps: torch.Tensor,
+ pred_kwargs: dict,
+ batch: 'DataLoaderBatchDTO',
+ noise: torch.Tensor,
+ sd: 'StableDiffusion',
+ **kwargs
+):
+ with torch.no_grad():
+ dtype = get_torch_dtype(sd.torch_dtype)
+ device = sd.device_torch
+
+ conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+ unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+        # Add noise to the unconditional and conditional latents
+ unconditional_noisy_latents = sd.noise_scheduler.add_noise(
+ unconditional_latents,
+ noise,
+ timesteps
+ )
+ conditional_noisy_latents = sd.noise_scheduler.add_noise(
+ conditional_latents,
+ noise,
+ timesteps
+ )
+
+ # was_network_active = self.network.is_active
+ sd.network.is_active = False
+ sd.unet.eval()
+
+ target_differential = unconditional_latents - conditional_latents
+ # scale our loss by the differential scaler
+ target_differential_abs = target_differential.abs()
+ target_differential_abs_min = \
+ target_differential_abs.min(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
+ target_differential_abs_max = \
+ target_differential_abs.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
+
+ min_guidance = 1.0
+ max_guidance = 2.0
+
+ differential_scaler = value_map(
+ target_differential_abs,
+ target_differential_abs_min,
+ target_differential_abs_max,
+ min_guidance,
+ max_guidance
+ ).detach()
+
+
+ # With LoRA network bypassed, predict noise to get a baseline of what the network
+ # wants to do with the latents + noise. Pass our target latents here for the input.
+ target_unconditional = sd.predict_noise(
+ latents=unconditional_noisy_latents.to(device, dtype=dtype).detach(),
+ conditional_embeddings=conditional_embeds.to(device, dtype=dtype).detach(),
+ timestep=timesteps,
+ guidance_scale=1.0,
+ **pred_kwargs # adapter residuals in here
+ ).detach()
+ prior_prediction_loss = torch.nn.functional.mse_loss(
+ target_unconditional.float(),
+ noise.float(),
+ reduction="none"
+ ).detach().clone()
+
+ # turn the LoRA network back on.
+ sd.unet.train()
+ sd.network.is_active = True
+ sd.network.multiplier = network_weight_list + [x + -1.0 for x in network_weight_list]
+
+ # with LoRA active, predict the noise with the scaled differential latents added. This will allow us
+ # the opportunity to predict the differential + noise that was added to the latents.
+ prediction = sd.predict_noise(
+ latents=torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0).to(device, dtype=dtype).detach(),
+ conditional_embeddings=concat_prompt_embeds([conditional_embeds, conditional_embeds]).to(device, dtype=dtype).detach(),
+ timestep=torch.cat([timesteps, timesteps], dim=0),
+ guidance_scale=1.0,
+ **pred_kwargs # adapter residuals in here
+ )
+
+ prediction_conditional, prediction_unconditional = torch.chunk(prediction, 2, dim=0)
+
+ conditional_loss = torch.nn.functional.mse_loss(
+ prediction_conditional.float(),
+ noise.float(),
+ reduction="none"
+ )
+
+ unconditional_loss = torch.nn.functional.mse_loss(
+ prediction_unconditional.float(),
+ noise.float(),
+ reduction="none"
+ )
+
+ positive_loss = torch.abs(
+ conditional_loss.float() - prior_prediction_loss.float(),
+ )
+ # scale our loss by the differential scaler
+ positive_loss = positive_loss * differential_scaler
+
+ positive_loss = positive_loss.mean([1, 2, 3])
+
+ polar_loss = torch.abs(
+ conditional_loss.float() - unconditional_loss.float(),
+ ).mean([1, 2, 3])
+
+
+ positive_loss = positive_loss.mean() + polar_loss.mean()
+
+
+ positive_loss.backward()
+ # loss = positive_loss.detach() + negative_loss.detach()
+ loss = positive_loss.detach()
+
+ # add a grad so other backward does not fail
+ loss.requires_grad_(True)
+
+ # restore network
+ sd.network.multiplier = network_weight_list
+
+ return loss
+
+def get_guided_loss_polarity(
+ noisy_latents: torch.Tensor,
+ conditional_embeds: PromptEmbeds,
+ match_adapter_assist: bool,
+ network_weight_list: list,
+ timesteps: torch.Tensor,
+ pred_kwargs: dict,
+ batch: 'DataLoaderBatchDTO',
+ noise: torch.Tensor,
+ sd: 'StableDiffusion',
+ scaler=None,
+ **kwargs
+):
+ dtype = get_torch_dtype(sd.torch_dtype)
+ device = sd.device_torch
+ with torch.no_grad():
+ dtype = get_torch_dtype(dtype)
+ noise = noise.to(device, dtype=dtype).detach()
+
+ conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+ unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+ target_pos = noise
+ target_neg = noise
+
+ if sd.is_flow_matching:
+ # set the timesteps for flow matching as linear since we will do weighing
+ sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
+ target_pos = (noise - conditional_latents).detach()
+ target_neg = (noise - unconditional_latents).detach()
+
+ conditional_noisy_latents = sd.add_noise(
+ conditional_latents,
+ noise,
+ timesteps
+ ).detach()
+
+ unconditional_noisy_latents = sd.add_noise(
+ unconditional_latents,
+ noise,
+ timesteps
+ ).detach()
+
+ # double up everything to run it through all at once
+ cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+ cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
+ cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
+
+ negative_network_weights = [weight * -1.0 for weight in network_weight_list]
+ positive_network_weights = [weight * 1.0 for weight in network_weight_list]
+ cat_network_weight_list = positive_network_weights + negative_network_weights
+
+ # turn the LoRA network back on.
+ sd.unet.train()
+ sd.network.is_active = True
+
+ sd.network.multiplier = cat_network_weight_list
+
+ # do our prediction with LoRA active on the scaled guidance latents
+ prediction = sd.predict_noise(
+ latents=cat_latents.to(device, dtype=dtype).detach(),
+ conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
+ timestep=cat_timesteps,
+ guidance_scale=1.0,
+ **pred_kwargs # adapter residuals in here
+ )
+
+ pred_pos, pred_neg = torch.chunk(prediction, 2, dim=0)
+
+ pred_loss = torch.nn.functional.mse_loss(
+ pred_pos.float(),
+ target_pos.float(),
+ reduction="none"
+ )
+ # pred_loss = pred_loss.mean([1, 2, 3])
+
+ pred_neg_loss = torch.nn.functional.mse_loss(
+ pred_neg.float(),
+ target_neg.float(),
+ reduction="none"
+ )
+
+ loss = pred_loss + pred_neg_loss
+
+ # if sd.is_flow_matching:
+ # timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
+ # loss = loss * timestep_weight
+
+
+ loss = loss.mean([1, 2, 3])
+ loss = loss.mean()
+ if scaler is not None:
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ # detach it so parent class can run backward on no grads without throwing error
+ loss = loss.detach()
+ loss.requires_grad_(True)
+
+ return loss
+
+
+
+def get_guided_tnt(
+ noisy_latents: torch.Tensor,
+ conditional_embeds: PromptEmbeds,
+ match_adapter_assist: bool,
+ network_weight_list: list,
+ timesteps: torch.Tensor,
+ pred_kwargs: dict,
+ batch: 'DataLoaderBatchDTO',
+ noise: torch.Tensor,
+ sd: 'StableDiffusion',
+ prior_pred: torch.Tensor = None,
+ **kwargs
+):
+ dtype = get_torch_dtype(sd.torch_dtype)
+ device = sd.device_torch
+ with torch.no_grad():
+ dtype = get_torch_dtype(dtype)
+ noise = noise.to(device, dtype=dtype).detach()
+
+ conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+ unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+ conditional_noisy_latents = sd.add_noise(
+ conditional_latents,
+ noise,
+ timesteps
+ ).detach()
+
+ unconditional_noisy_latents = sd.add_noise(
+ unconditional_latents,
+ noise,
+ timesteps
+ ).detach()
+
+ # double up everything to run it through all at once
+ cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+ cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
+ cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
+
+
+ # turn the LoRA network back on.
+ sd.unet.train()
+ if sd.network is not None:
+ cat_network_weight_list = [weight for weight in network_weight_list * 2]
+ sd.network.multiplier = cat_network_weight_list
+ sd.network.is_active = True
+
+
+ prediction = sd.predict_noise(
+ latents=cat_latents.to(device, dtype=dtype).detach(),
+ conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
+ timestep=cat_timesteps,
+ guidance_scale=1.0,
+ **pred_kwargs # adapter residuals in here
+ )
+ this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)
+
+ this_loss = torch.nn.functional.mse_loss(
+ this_prediction.float(),
+ noise.float(),
+ reduction="none"
+ )
+
+ that_loss = torch.nn.functional.mse_loss(
+ that_prediction.float(),
+ noise.float(),
+ reduction="none"
+ )
+
+ this_loss = this_loss.mean([1, 2, 3])
+ # negative loss on that
+ that_loss = -that_loss.mean([1, 2, 3])
+
+ with torch.no_grad():
+ # match that loss with this loss so it is not a negative value and same scale
+ that_loss_scaler = torch.abs(this_loss) / torch.abs(that_loss)
+
+ that_loss = that_loss * that_loss_scaler * 0.01
+
+ loss = this_loss + that_loss
+
+ loss = loss.mean()
+
+ loss.backward()
+
+ # detach it so parent class can run backward on no grads without throwing error
+ loss = loss.detach()
+ loss.requires_grad_(True)
+
+ return loss
+
+
+
+# this processes all guidance losses based on the batch information
+def get_guidance_loss(
+ noisy_latents: torch.Tensor,
+ conditional_embeds: 'PromptEmbeds',
+ match_adapter_assist: bool,
+ network_weight_list: list,
+ timesteps: torch.Tensor,
+ pred_kwargs: dict,
+ batch: 'DataLoaderBatchDTO',
+ noise: torch.Tensor,
+ sd: 'StableDiffusion',
+ unconditional_embeds: Optional[PromptEmbeds] = None,
+ mask_multiplier=None,
+ prior_pred=None,
+ scaler=None,
+ **kwargs
+):
+ # TODO add others and process individual batch items separately
+ guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
+
+ if guidance_type == "targeted":
+ assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted guidance"
+ return get_targeted_guidance_loss(
+ noisy_latents,
+ conditional_embeds,
+ match_adapter_assist,
+ network_weight_list,
+ timesteps,
+ pred_kwargs,
+ batch,
+ noise,
+ sd,
+ **kwargs
+ )
+ elif guidance_type == "polarity":
+ assert unconditional_embeds is None, "Unconditional embeds are not supported for polarity guidance"
+ return get_guided_loss_polarity(
+ noisy_latents,
+ conditional_embeds,
+ match_adapter_assist,
+ network_weight_list,
+ timesteps,
+ pred_kwargs,
+ batch,
+ noise,
+ sd,
+ scaler=scaler,
+ **kwargs
+ )
+ elif guidance_type == "tnt":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
+ return get_guided_tnt(
+ noisy_latents,
+ conditional_embeds,
+ match_adapter_assist,
+ network_weight_list,
+ timesteps,
+ pred_kwargs,
+ batch,
+ noise,
+ sd,
+ prior_pred=prior_pred,
+ **kwargs
+ )
+
+ elif guidance_type == "targeted_polarity":
+ assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
+ return get_targeted_polarity_loss(
+ noisy_latents,
+ conditional_embeds,
+ match_adapter_assist,
+ network_weight_list,
+ timesteps,
+ pred_kwargs,
+ batch,
+ noise,
+ sd,
+ **kwargs
+ )
+ elif guidance_type == "direct":
+ return get_direct_guidance_loss(
+ noisy_latents,
+ conditional_embeds,
+ match_adapter_assist,
+ network_weight_list,
+ timesteps,
+ pred_kwargs,
+ batch,
+ noise,
+ sd,
+ unconditional_embeds=unconditional_embeds,
+ mask_multiplier=mask_multiplier,
+ prior_pred=prior_pred,
+ **kwargs
+ )
+ else:
+ raise NotImplementedError(f"Guidance type {guidance_type} is not implemented")
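+
+
+# Note on dispatch (descriptive comment, not extra behavior): the guidance type
+# is read from the dataset config of the first file item in the batch, so a
+# dataset configured with guidance_type "polarity" routes its batches through
+# get_guided_loss_polarity, and so on for the other Literal values above. All
+# of these helpers call backward() themselves and return a detached loss with
+# requires_grad re-enabled, so the caller should treat the returned value as
+# already back-propagated.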
diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9f306e077ee50594f6ca6a4a005025434dcdf7
--- /dev/null
+++ b/toolkit/image_utils.py
@@ -0,0 +1,516 @@
+# ref https://github.com/scardine/image_size/blob/master/get_image_size.py
+import atexit
+import collections
+import json
+import os
+import io
+import struct
+import threading
+from typing import TYPE_CHECKING
+
+import cv2
+import numpy as np
+import torch
+from diffusers import AutoencoderTiny
+
+FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
+
+
+class UnknownImageFormat(Exception):
+ pass
+
+
+types = collections.OrderedDict()
+BMP = types['BMP'] = 'BMP'
+GIF = types['GIF'] = 'GIF'
+ICO = types['ICO'] = 'ICO'
+JPEG = types['JPEG'] = 'JPEG'
+PNG = types['PNG'] = 'PNG'
+TIFF = types['TIFF'] = 'TIFF'
+
+image_fields = ['path', 'type', 'file_size', 'width', 'height']
+
+
+class Image(collections.namedtuple('Image', image_fields)):
+
+ def to_str_row(self):
+ return ("%d\t%d\t%d\t%s\t%s" % (
+ self.width,
+ self.height,
+ self.file_size,
+ self.type,
+ self.path.replace('\t', '\\t'),
+ ))
+
+ def to_str_row_verbose(self):
+ return ("%d\t%d\t%d\t%s\t%s\t##%s" % (
+ self.width,
+ self.height,
+ self.file_size,
+ self.type,
+ self.path.replace('\t', '\\t'),
+ self))
+
+ def to_str_json(self, indent=None):
+ return json.dumps(self._asdict(), indent=indent)
+
+
+def get_image_size(file_path):
+ """
+ Return (width, height) for a given img file content - no external
+ dependencies except the os and struct builtin modules
+ """
+ img = get_image_metadata(file_path)
+ return (img.width, img.height)
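+
+# Usage sketch (hypothetical path): `width, height = get_image_size("img.png")`.
+# Unsupported formats raise UnknownImageFormat and missing files raise OSError.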
+
+
+def get_image_size_from_bytesio(input, size):
+ """
+ Return (width, height) for a given img file content - no external
+ dependencies except the os and struct builtin modules
+
+ Args:
+ input (io.IOBase): io object support read & seek
+ size (int): size of buffer in byte
+ """
+ img = get_image_metadata_from_bytesio(input, size)
+ return (img.width, img.height)
+
+
+def get_image_metadata(file_path):
+ """
+ Return an `Image` object for a given img file content - no external
+ dependencies except the os and struct builtin modules
+
+ Args:
+ file_path (str): path to an image file
+
+ Returns:
+ Image: (path, type, file_size, width, height)
+ """
+ size = os.path.getsize(file_path)
+
+ # be explicit with open arguments - we need binary mode
+ with io.open(file_path, "rb") as input:
+ return get_image_metadata_from_bytesio(input, size, file_path)
+
+
+def get_image_metadata_from_bytesio(input, size, file_path=None):
+ """
+ Return an `Image` object for a given img file content - no external
+ dependencies except the os and struct builtin modules
+
+ Args:
+ input (io.IOBase): io object support read & seek
+ size (int): size of buffer in byte
+ file_path (str): path to an image file
+
+ Returns:
+ Image: (path, type, file_size, width, height)
+ """
+ height = -1
+ width = -1
+ data = input.read(26)
+ msg = " raised while trying to decode as JPEG."
+
+ if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'):
+ # GIFs
+ imgtype = GIF
+ w, h = struct.unpack("= 24) and data.startswith(b'\211PNG\r\n\032\n')
+ and (data[12:16] == b'IHDR')):
+ # PNGs
+ imgtype = PNG
+ w, h = struct.unpack(">LL", data[16:24])
+ width = int(w)
+ height = int(h)
+ elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'):
+ # older PNGs
+ imgtype = PNG
+ w, h = struct.unpack(">LL", data[8:16])
+ width = int(w)
+ height = int(h)
+ elif (size >= 2) and data.startswith(b'\377\330'):
+ # JPEG
+ imgtype = JPEG
+ input.seek(0)
+ input.read(2)
+ b = input.read(1)
+ try:
+ while (b and ord(b) != 0xDA):
+ while (ord(b) != 0xFF):
+ b = input.read(1)
+ while (ord(b) == 0xFF):
+ b = input.read(1)
+ if (ord(b) >= 0xC0 and ord(b) <= 0xC3):
+ input.read(3)
+ h, w = struct.unpack(">HH", input.read(4))
+ break
+ else:
+ input.read(
+ int(struct.unpack(">H", input.read(2))[0]) - 2)
+ b = input.read(1)
+ width = int(w)
+ height = int(h)
+ except struct.error:
+ raise UnknownImageFormat("StructError" + msg)
+ except ValueError:
+ raise UnknownImageFormat("ValueError" + msg)
+ except Exception as e:
+ raise UnknownImageFormat(e.__class__.__name__ + msg)
+ elif (size >= 26) and data.startswith(b'BM'):
+ # BMP
+ imgtype = 'BMP'
+        headersize = struct.unpack("<I", data[14:18])[0]
+        if headersize == 12:
+            w, h = struct.unpack("<HH", data[18:22])
+            width = int(w)
+            height = int(h)
+        elif headersize >= 40:
+            w, h = struct.unpack("<ii", data[18:26])
+            width = int(w)
+            # as h is negative when stored upside down
+            height = abs(int(h))
+        else:
+            raise UnknownImageFormat(
+                "Unknown DIB header size: " + str(headersize))
+    elif (size >= 8) and data[:4] in (b"II\052\000", b"MM\000\052"):
+ # Standard TIFF, big- or little-endian
+ # BigTIFF and other different but TIFF-like formats are not
+ # supported currently
+ imgtype = TIFF
+ byteOrder = data[:2]
+ boChar = ">" if byteOrder == "MM" else "<"
+ # maps TIFF type id to size (in bytes)
+ # and python format char for struct
+ tiffTypes = {
+ 1: (1, boChar + "B"), # BYTE
+ 2: (1, boChar + "c"), # ASCII
+ 3: (2, boChar + "H"), # SHORT
+ 4: (4, boChar + "L"), # LONG
+ 5: (8, boChar + "LL"), # RATIONAL
+ 6: (1, boChar + "b"), # SBYTE
+ 7: (1, boChar + "c"), # UNDEFINED
+ 8: (2, boChar + "h"), # SSHORT
+ 9: (4, boChar + "l"), # SLONG
+ 10: (8, boChar + "ll"), # SRATIONAL
+ 11: (4, boChar + "f"), # FLOAT
+ 12: (8, boChar + "d") # DOUBLE
+ }
+ ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
+ try:
+ countSize = 2
+ input.seek(ifdOffset)
+ ec = input.read(countSize)
+ ifdEntryCount = struct.unpack(boChar + "H", ec)[0]
+ # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4
+ # bytes: value offset
+ ifdEntrySize = 12
+ for i in range(ifdEntryCount):
+ entryOffset = ifdOffset + countSize + i * ifdEntrySize
+ input.seek(entryOffset)
+ tag = input.read(2)
+ tag = struct.unpack(boChar + "H", tag)[0]
+ if (tag == 256 or tag == 257):
+ # if type indicates that value fits into 4 bytes, value
+ # offset is not an offset but value itself
+ type = input.read(2)
+ type = struct.unpack(boChar + "H", type)[0]
+ if type not in tiffTypes:
+ raise UnknownImageFormat(
+ "Unkown TIFF field type:" +
+ str(type))
+ typeSize = tiffTypes[type][0]
+ typeChar = tiffTypes[type][1]
+ input.seek(entryOffset + 8)
+ value = input.read(typeSize)
+ value = int(struct.unpack(typeChar, value)[0])
+ if tag == 256:
+ width = value
+ else:
+ height = value
+ if width > -1 and height > -1:
+ break
+ except Exception as e:
+ raise UnknownImageFormat(str(e))
+ elif size >= 2:
+ # see http://en.wikipedia.org/wiki/ICO_(file_format)
+ imgtype = 'ICO'
+ input.seek(0)
+ reserved = input.read(2)
+        if 0 != struct.unpack("<H", reserved)[0]:
+            raise UnknownImageFormat(FILE_UNKNOWN)
+        format = input.read(2)
+        assert 1 == struct.unpack("<H", format)[0]
+        num = input.read(2)
+        num = struct.unpack("<H", num)[0]
+        if num > 1:
+ import warnings
+ warnings.warn("ICO File contains more than one image")
+ # http://msdn.microsoft.com/en-us/library/ms997538.aspx
+ w = input.read(1)
+ h = input.read(1)
+ width = ord(w)
+ height = ord(h)
+ else:
+ raise UnknownImageFormat(FILE_UNKNOWN)
+
+ return Image(path=file_path,
+ type=imgtype,
+ file_size=size,
+ width=width,
+ height=height)
+
+
+import unittest
+
+
+class Test_get_image_size(unittest.TestCase):
+ data = [{
+ 'path': 'lookmanodeps.png',
+ 'width': 251,
+ 'height': 208,
+ 'file_size': 22228,
+ 'type': 'PNG'}]
+
+ def setUp(self):
+ pass
+
+ def test_get_image_size_from_bytesio(self):
+ img = self.data[0]
+ p = img['path']
+ with io.open(p, 'rb') as fp:
+ b = fp.read()
+ fp = io.BytesIO(b)
+ sz = len(b)
+ output = get_image_size_from_bytesio(fp, sz)
+ self.assertTrue(output)
+ self.assertEqual(output,
+ (img['width'],
+ img['height']))
+
+ def test_get_image_metadata_from_bytesio(self):
+ img = self.data[0]
+ p = img['path']
+ with io.open(p, 'rb') as fp:
+ b = fp.read()
+ fp = io.BytesIO(b)
+ sz = len(b)
+ output = get_image_metadata_from_bytesio(fp, sz)
+ self.assertTrue(output)
+ for field in image_fields:
+ self.assertEqual(getattr(output, field), None if field == 'path' else img[field])
+
+ def test_get_image_metadata(self):
+ img = self.data[0]
+ output = get_image_metadata(img['path'])
+ self.assertTrue(output)
+ for field in image_fields:
+ self.assertEqual(getattr(output, field), img[field])
+
+ def test_get_image_metadata__ENOENT_OSError(self):
+ with self.assertRaises(OSError):
+ get_image_metadata('THIS_DOES_NOT_EXIST')
+
+ def test_get_image_metadata__not_an_image_UnknownImageFormat(self):
+ with self.assertRaises(UnknownImageFormat):
+ get_image_metadata('README.rst')
+
+ def test_get_image_size(self):
+ img = self.data[0]
+ output = get_image_size(img['path'])
+ self.assertTrue(output)
+ self.assertEqual(output,
+ (img['width'],
+ img['height']))
+
+ def tearDown(self):
+ pass
+
+
+def main(argv=None):
+ """
+ Print image metadata fields for the given file path.
+
+ Keyword Arguments:
+ argv (list): commandline arguments (e.g. sys.argv[1:])
+ Returns:
+ int: zero for OK
+ """
+ import logging
+ import optparse
+ import sys
+
+ prs = optparse.OptionParser(
+ usage="%prog [-v|--verbose] [--json|--json-indent] []",
+ description="Print metadata for the given image paths "
+ "(without image library bindings).")
+
+ prs.add_option('--json',
+ dest='json',
+ action='store_true')
+ prs.add_option('--json-indent',
+ dest='json_indent',
+ action='store_true')
+
+ prs.add_option('-v', '--verbose',
+ dest='verbose',
+ action='store_true', )
+ prs.add_option('-q', '--quiet',
+ dest='quiet',
+ action='store_true', )
+ prs.add_option('-t', '--test',
+ dest='run_tests',
+ action='store_true', )
+
+ argv = list(argv) if argv is not None else sys.argv[1:]
+ (opts, args) = prs.parse_args(args=argv)
+ loglevel = logging.INFO
+ if opts.verbose:
+ loglevel = logging.DEBUG
+ elif opts.quiet:
+ loglevel = logging.ERROR
+ logging.basicConfig(level=loglevel)
+ log = logging.getLogger()
+ log.debug('argv: %r', argv)
+ log.debug('opts: %r', opts)
+ log.debug('args: %r', args)
+
+ if opts.run_tests:
+ import sys
+ sys.argv = [sys.argv[0]] + args
+ import unittest
+ return unittest.main()
+
+ output_func = Image.to_str_row
+ if opts.json_indent:
+ import functools
+ output_func = functools.partial(Image.to_str_json, indent=2)
+ elif opts.json:
+ output_func = Image.to_str_json
+ elif opts.verbose:
+ output_func = Image.to_str_row_verbose
+
+ EX_OK = 0
+ EX_NOT_OK = 2
+
+ if len(args) < 1:
+ prs.print_help()
+ print('')
+ prs.error("You must specify one or more paths to image files")
+
+ errors = []
+ for path_arg in args:
+ try:
+ img = get_image_metadata(path_arg)
+ print(output_func(img))
+ except KeyboardInterrupt:
+ raise
+ except OSError as e:
+ log.error((path_arg, e))
+ errors.append((path_arg, e))
+ except Exception as e:
+ log.exception(e)
+ errors.append((path_arg, e))
+ pass
+ if len(errors):
+ import pprint
+ print("ERRORS", file=sys.stderr)
+ print("======", file=sys.stderr)
+ print(pprint.pformat(errors, indent=2), file=sys.stderr)
+ return EX_NOT_OK
+ return EX_OK
+
+
+is_window_shown = False
+display_lock = threading.Lock()
+current_img = None
+update_event = threading.Event()
+
+def update_image(img, name):
+ global current_img
+ with display_lock:
+ current_img = (img, name)
+ update_event.set()
+
+def display_image_in_thread():
+ global is_window_shown
+
+ def display_img():
+ global current_img
+ while True:
+ update_event.wait()
+ with display_lock:
+ if current_img:
+ img, name = current_img
+ cv2.imshow(name, img)
+ current_img = None
+ update_event.clear()
+ if cv2.waitKey(1) & 0xFF == 27: # Esc key to stop
+ cv2.destroyAllWindows()
+ print('\nESC pressed, stopping')
+ break
+
+ if not is_window_shown:
+ is_window_shown = True
+ threading.Thread(target=display_img, daemon=True).start()
+
+
+def show_img(img, name='AI Toolkit'):
+ img = np.clip(img, 0, 255).astype(np.uint8)
+ update_image(img[:, :, ::-1], name)
+ if not is_window_shown:
+ display_image_in_thread()
+
+
+def show_tensors(imgs: torch.Tensor, name='AI Toolkit'):
+ if len(imgs.shape) == 4:
+ img_list = torch.chunk(imgs, imgs.shape[0], dim=0)
+ else:
+ img_list = [imgs]
+
+ img = torch.cat(img_list, dim=3)
+ img = img / 2 + 0.5
+ img_numpy = img.to(torch.float32).detach().cpu().numpy()
+ img_numpy = np.clip(img_numpy, 0, 1) * 255
+ img_numpy = img_numpy.transpose(0, 2, 3, 1)
+ img_numpy = img_numpy.astype(np.uint8)
+
+ show_img(img_numpy[0], name=name)
+
+
+def show_latents(latents: torch.Tensor, vae: 'AutoencoderTiny', name='AI Toolkit'):
+ if vae.device == 'cpu':
+ vae.to(latents.device)
+ latents = latents / vae.config['scaling_factor']
+ imgs = vae.decode(latents).sample
+ show_tensors(imgs, name=name)
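+
+# Display path (descriptive comment): show_tensors expects images in [-1, 1]
+# and rescales them via img / 2 + 0.5 before converting to uint8, while
+# show_latents first divides by the VAE scaling factor and decodes with the
+# tiny VAE before handing the result to show_tensors. Both funnel through
+# show_img, which flips RGB -> BGR for cv2 and renders on a background thread.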
+
+
+def on_exit():
+ if is_window_shown:
+ cv2.destroyAllWindows()
+
+
+def reduce_contrast(tensor, factor):
+ # Ensure factor is between 0 and 1
+ factor = max(0, min(factor, 1))
+
+ # Calculate the mean of the tensor
+ mean = torch.mean(tensor)
+
+ # Reduce contrast
+ adjusted_tensor = (tensor - mean) * factor + mean
+
+ # Clip values to ensure they stay within -1 to 1 range
+ return torch.clamp(adjusted_tensor, -1.0, 1.0)
+
+atexit.register(on_exit)
+
+if __name__ == "__main__":
+ import sys
+
+ sys.exit(main(argv=sys.argv[1:]))
diff --git a/toolkit/inversion_utils.py b/toolkit/inversion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a61d83bd8d62efbfa2c3b99069f7c6bb0f81ca
--- /dev/null
+++ b/toolkit/inversion_utils.py
@@ -0,0 +1,410 @@
+# ref https://huggingface.co/spaces/editing-images/ledits/blob/main/inversion_utils.py
+
+import torch
+import os
+from tqdm import tqdm
+
+from toolkit import train_tools
+from toolkit.prompt_utils import PromptEmbeds
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def mu_tilde(model, xt, x0, timestep):
+ "mu_tilde(x_t, x_0) DDPM paper eq. 7"
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+ prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+ alpha_t = model.scheduler.alphas[timestep]
+ beta_t = 1 - alpha_t
+ alpha_bar = model.scheduler.alphas_cumprod[timestep]
+ return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + (
+ (alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
+
+
+def sample_xts_from_x0(sd: StableDiffusion, sample: torch.Tensor, num_inference_steps=50):
+ """
+ Samples from P(x_1:T|x_0)
+ """
+ # torch.manual_seed(43256465436)
+ alpha_bar = sd.noise_scheduler.alphas_cumprod
+ sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
+ alphas = sd.noise_scheduler.alphas
+ betas = 1 - alphas
+ # variance_noise_shape = (
+ # num_inference_steps,
+ # sd.unet.in_channels,
+ # sd.unet.sample_size,
+ # sd.unet.sample_size)
+ variance_noise_shape = list(sample.shape)
+ variance_noise_shape[0] = num_inference_steps
+
+ timesteps = sd.noise_scheduler.timesteps.to(sd.device)
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+ xts = torch.zeros(variance_noise_shape).to(sample.device, dtype=torch.float16)
+ for t in reversed(timesteps):
+ idx = t_to_idx[int(t)]
+ xts[idx] = sample * (alpha_bar[t] ** 0.5) + torch.randn_like(sample, dtype=torch.float16) * sqrt_one_minus_alpha_bar[t]
+ xts = torch.cat([xts, sample], dim=0)
+
+ return xts
+
+
+def encode_text(model, prompts):
+ text_input = model.tokenizer(
+ prompts,
+ padding="max_length",
+ max_length=model.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ with torch.no_grad():
+ text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
+ return text_encoding
+
+
+def forward_step(sd: StableDiffusion, model_output, timestep, sample):
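+ # one inversion step x_t -> x_{t+1}: estimate x_0 from the eps prediction (DDIM eq. 12),
+ # then re-noise it to the next (higher) timestep with the scheduler's add_noise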
+ next_timestep = min(
+ sd.noise_scheduler.config['num_train_timesteps'] - 2,
+ timestep + sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+ )
+
+ # 2. compute alphas, betas
+ alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
+ # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_timestep >= 0 else self.scheduler.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 5. TODO: simple noising implementation
+ next_sample = sd.noise_scheduler.add_noise(
+ pred_original_sample,
+ model_output,
+ torch.LongTensor([next_timestep]))
+ return next_sample
+
+
+def get_variance(sd: StableDiffusion, timestep): # , prev_timestep):
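+ # DDIM eq. 16: sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1})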
+ prev_timestep = timestep - sd.noise_scheduler.config['num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+ alpha_prod_t = sd.noise_scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
+ prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ return variance
+
+
+def get_time_ids_from_latents(sd: StableDiffusion, latents: torch.Tensor):
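+ # SDXL micro-conditioning: build add_time_ids as (original_size + crop_top_left + target_size)
+ # from the latent resolution with no crop offsets; non-XL models return None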
+ VAE_SCALE_FACTOR = 2 ** (len(sd.vae.config['block_out_channels']) - 1)
+ if sd.is_xl:
+ bs, ch, h, w = list(latents.shape)
+
+ height = h * VAE_SCALE_FACTOR
+ width = w * VAE_SCALE_FACTOR
+
+ dtype = latents.dtype
+ # just do it without any cropping nonsense
+ target_size = (height, width)
+ original_size = (height, width)
+ crops_coords_top_left = (0, 0)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
+
+ batch_time_ids = torch.cat(
+ [add_time_ids for _ in range(bs)]
+ )
+ return batch_time_ids
+ else:
+ return None
+
+
+def inversion_forward_process(
+ sd: StableDiffusion,
+ sample: torch.Tensor,
+ conditional_embeddings: PromptEmbeds,
+ unconditional_embeddings: PromptEmbeds,
+ etas=None,
+ prog_bar=False,
+ cfg_scale=3.5,
+ num_inference_steps=50, eps=None
+):
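+ # Inverts `sample` back to noise. With eta == 0 this is a DDIM-style inversion; with eta > 0 it
+ # samples x_1..x_T from x_0 and records per-step noise maps `zs` (edit-friendly DDPM inversion,
+ # as in the LEDITS reference at the top of this file) so inversion_reverse_process can retrace the trajectory.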
+ current_num_timesteps = len(sd.noise_scheduler.timesteps)
+ sd.noise_scheduler.set_timesteps(num_inference_steps, device=sd.device)
+
+ timesteps = sd.noise_scheduler.timesteps.to(sd.device)
+ # variance_noise_shape = (
+ # num_inference_steps,
+ # sd.unet.in_channels,
+ # sd.unet.sample_size,
+ # sd.unet.sample_size
+ # )
+ variance_noise_shape = list(sample.shape)
+ variance_noise_shape[0] = num_inference_steps
+ if etas is None or (type(etas) in [int, float] and etas == 0):
+ eta_is_zero = True
+ zs = None
+ else:
+ eta_is_zero = False
+ if type(etas) in [int, float]: etas = [etas] * sd.noise_scheduler.num_inference_steps
+ xts = sample_xts_from_x0(sd, sample, num_inference_steps=num_inference_steps)
+ alpha_bar = sd.noise_scheduler.alphas_cumprod
+ zs = torch.zeros(size=variance_noise_shape, device=sd.device, dtype=torch.float16)
+
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+ noisy_sample = sample
+ op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
+
+ for timestep in op:
+ idx = t_to_idx[int(timestep)]
+ # 1. predict noise residual
+ if not eta_is_zero:
+ noisy_sample = xts[idx][None]
+
+ added_cond_kwargs = {}
+
+ with torch.no_grad():
+ text_embeddings = train_tools.concat_prompt_embeddings(
+ unconditional_embeddings, # negative embedding
+ conditional_embeddings, # positive embedding
+ 1, # batch size
+ )
+ if sd.is_xl:
+ add_time_ids = get_time_ids_from_latents(sd, noisy_sample)
+ # add extra for cfg
+ add_time_ids = torch.cat(
+ [add_time_ids] * 2, dim=0
+ )
+
+ added_cond_kwargs = {
+ "text_embeds": text_embeddings.pooled_embeds,
+ "time_ids": add_time_ids,
+ }
+
+ # double up for cfg
+ latent_model_input = torch.cat(
+ [noisy_sample] * 2, dim=0
+ )
+
+ noise_pred = sd.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=text_embeddings.text_embeds,
+ added_cond_kwargs=added_cond_kwargs,
+ ).sample
+
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+
+ # out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=uncond_embedding)
+ # cond_out = sd.unet.forward(noisy_sample, timestep=timestep, encoder_hidden_states=text_embeddings)
+
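+ # classifier-free guidance: eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond);
+ # assumes concat_prompt_embeddings placed the unconditional embedding first, matching the chunk order above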
+ noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
+
+ if eta_is_zero:
+ # 2. compute more noisy image and set x_t -> x_t+1
+ noisy_sample = forward_step(sd, noise_pred, timestep, noisy_sample)
+ xts = None
+
+ else:
+ xtm1 = xts[idx + 1][None]
+ # pred of x0
+ pred_original_sample = (noisy_sample - (1 - alpha_bar[timestep]) ** 0.5 * noise_pred) / alpha_bar[
+ timestep] ** 0.5
+
+ # direction to xt
+ prev_timestep = timestep - sd.noise_scheduler.config[
+ 'num_train_timesteps'] // sd.noise_scheduler.num_inference_steps
+ alpha_prod_t_prev = sd.noise_scheduler.alphas_cumprod[
+ prev_timestep] if prev_timestep >= 0 else sd.noise_scheduler.final_alpha_cumprod
+
+ variance = get_variance(sd, timestep)
+ pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+ zs[idx] = z
+
+ # correction to avoid error accumulation
+ xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+ xts[idx + 1] = xtm1
+
+ if zs is not None:
+ zs[-1] = torch.zeros_like(zs[-1])
+
+ # restore timesteps
+ sd.noise_scheduler.set_timesteps(current_num_timesteps, device=sd.device)
+
+ return noisy_sample, zs, xts
+
+
+#
+# def inversion_forward_process(
+# model,
+# sample,
+# etas=None,
+# prog_bar=False,
+# prompt="",
+# cfg_scale=3.5,
+# num_inference_steps=50, eps=None
+# ):
+# if not prompt == "":
+# text_embeddings = encode_text(model, prompt)
+# uncond_embedding = encode_text(model, "")
+# timesteps = model.scheduler.timesteps.to(model.device)
+# variance_noise_shape = (
+# num_inference_steps,
+# model.unet.in_channels,
+# model.unet.sample_size,
+# model.unet.sample_size)
+# if etas is None or (type(etas) in [int, float] and etas == 0):
+# eta_is_zero = True
+# zs = None
+# else:
+# eta_is_zero = False
+# if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+# xts = sample_xts_from_x0(model, sample, num_inference_steps=num_inference_steps)
+# alpha_bar = model.scheduler.alphas_cumprod
+# zs = torch.zeros(size=variance_noise_shape, device=model.device, dtype=torch.float16)
+#
+# t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+# noisy_sample = sample
+# op = tqdm(reversed(timesteps), desc="Inverting...") if prog_bar else reversed(timesteps)
+#
+# for t in op:
+# idx = t_to_idx[int(t)]
+# # 1. predict noise residual
+# if not eta_is_zero:
+# noisy_sample = xts[idx][None]
+#
+# with torch.no_grad():
+# out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=uncond_embedding)
+# if not prompt == "":
+# cond_out = model.unet.forward(noisy_sample, timestep=t, encoder_hidden_states=text_embeddings)
+#
+# if not prompt == "":
+# ## classifier free guidance
+# noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
+# else:
+# noise_pred = out.sample
+#
+# if eta_is_zero:
+# # 2. compute more noisy image and set x_t -> x_t+1
+# noisy_sample = forward_step(model, noise_pred, t, noisy_sample)
+#
+# else:
+# xtm1 = xts[idx + 1][None]
+# # pred of x0
+# pred_original_sample = (noisy_sample - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
+#
+# # direction to xt
+# prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+# alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+# prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+#
+# variance = get_variance(model, t)
+# pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+#
+# mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+#
+# z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+# zs[idx] = z
+#
+# # correction to avoid error accumulation
+# xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+# xts[idx + 1] = xtm1
+#
+# if not zs is None:
+# zs[-1] = torch.zeros_like(zs[-1])
+#
+# return noisy_sample, zs, xts
+
+
+def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+ # 2. compute alphas, betas
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[
+ prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ # variance = self.scheduler._get_variance(timestep, prev_timestep)
+ variance = get_variance(model, timestep) # , prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+ # Take care of asymmetric reverse process (asyrp)
+ model_output_direction = model_output
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
+ pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+ # 8. Add noise if eta > 0
+ if eta > 0:
+ if variance_noise is None:
+ variance_noise = torch.randn(model_output.shape, device=model.device, dtype=torch.float16)
+ sigma_z = eta * variance ** (0.5) * variance_noise
+ prev_sample = prev_sample + sigma_z
+
+ return prev_sample
+
+
+def inversion_reverse_process(
+ model,
+ xT,
+ etas=0,
+ prompts="",
+ cfg_scales=None,
+ prog_bar=False,
+ zs=None,
+ controller=None,
+ asyrp=False):
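+ # Denoises from xT using the stored per-step noise maps `zs` from inversion_forward_process,
+ # with classifier-free guidance; reusing the recorded zs reproduces the inverted trajectory.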
+ batch_size = len(prompts)
+
+ cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device, dtype=torch.float16)
+
+ text_embeddings = encode_text(model, prompts)
+ uncond_embedding = encode_text(model, [""] * batch_size)
+
+ if etas is None: etas = 0
+ if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+ assert len(etas) == model.scheduler.num_inference_steps
+ timesteps = model.scheduler.timesteps.to(model.device)
+
+ xt = xT.expand(batch_size, -1, -1, -1)
+ op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
+
+ for t in op:
+ idx = t_to_idx[int(t)]
+ ## Unconditional embedding
+ with torch.no_grad():
+ uncond_out = model.unet.forward(xt, timestep=t,
+ encoder_hidden_states=uncond_embedding)
+
+ ## Conditional embedding
+ if prompts:
+ with torch.no_grad():
+ cond_out = model.unet.forward(xt, timestep=t,
+ encoder_hidden_states=text_embeddings)
+
+ z = zs[idx] if zs is not None else None
+ z = z.expand(batch_size, -1, -1, -1)
+ if prompts:
+ ## classifier free guidance
+ noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+ else:
+ noise_pred = uncond_out.sample
+ # 2. compute less noisy image and set x_t -> x_t-1
+ xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
+ if controller is not None:
+ xt = controller.step_callback(xt)
+ return xt, zs
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4821e968728ee1102c826d06de5022713ca82256
--- /dev/null
+++ b/toolkit/ip_adapter.py
@@ -0,0 +1,1337 @@
+import random
+
+import torch
+import sys
+
+from PIL import Image
+from diffusers import Transformer2DModel
+from torch import nn
+from torch.nn import Parameter
+from torch.nn.modules.module import T
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
+from toolkit.models.zipper_resampler import ZipperResampler
+from toolkit.paths import REPOS_ROOT
+from toolkit.saving import load_ip_adapter_model
+from toolkit.train_tools import get_torch_dtype
+from toolkit.util.inverse_cfg import inverse_classifier_guidance
+
+sys.path.append(REPOS_ROOT)
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
+from collections import OrderedDict
+from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
+ AttnProcessor2_0
+from ipadapter.ip_adapter.ip_adapter import ImageProjModel
+from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
+from toolkit.config_modules import AdapterConfig
+from toolkit.prompt_utils import PromptEmbeds
+import weakref
+from diffusers import FluxTransformer2DModel
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection,
+ CLIPVisionModel,
+ AutoImageProcessor,
+ ConvNextModel,
+ ConvNextV2ForImageClassification,
+ ConvNextForImageClassification,
+ ConvNextImageProcessor
+)
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
+
+from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
+
+from transformers import ViTFeatureExtractor, ViTForImageClassification
+
+# gradient checkpointing
+from torch.utils.checkpoint import checkpoint
+
+import torch.nn.functional as F
+
+
+class MLPProjModelClipFace(torch.nn.Module):
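+ # Projects a face-ID / CLIP projection embedding into `num_tokens` cross-attention context
+ # tokens via a small MLP; the final linear layer is initialized near zero so the adapter
+ # starts with little influence on the base model.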
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.num_tokens = num_tokens
+ self.norm = torch.nn.LayerNorm(id_embeddings_dim)
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
+ torch.nn.GELU(),
+ torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
+ )
+ # Initialize the last linear layer weights near zero
+ torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
+ torch.nn.init.zeros_(self.proj[2].bias)
+ # # Custom initialization for LayerNorm to output near zero
+ # torch.nn.init.constant_(self.norm.weight, 0.1) # Small weights near zero
+ # torch.nn.init.zeros_(self.norm.bias) # Bias to zero
+
+ def forward(self, x):
+ x = self.norm(x)
+ x = self.proj(x)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ return x
+
+
+class CustomIPAttentionProcessor(IPAttnProcessor2_0):
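+ # Decoupled cross-attention (IP-Adapter): the trailing `num_tokens` entries of
+ # encoder_hidden_states are treated as image tokens and attended with separate
+ # to_k_ip / to_v_ip projections, then added to the text-attention output scaled by
+ # `scale`; when the adapter is inactive the image branch is skipped. An optional
+ # trainable ip_scaler can later be merged into those weights.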
+ def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False, full_token_scaler=False):
+ super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.train_scaler = train_scaler
+ if train_scaler:
+ if full_token_scaler:
+ self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+ else:
+ self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+ # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+ self.ip_scaler.requires_grad_(True)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ is_active = self.adapter_ref().is_active
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if is_active:
+ # since we are removing tokens, we need to adjust the sequence length
+ sequence_length = sequence_length - self.num_tokens
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ # will be none if disabled
+ if not is_active:
+ ip_hidden_states = None
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ try:
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ except Exception as e:
+ print(e)
+ raise e
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # will be none if disabled
+ if ip_hidden_states is not None:
+ # apply scaler
+ if self.train_scaler:
+ weight = self.ip_scaler
+ # reshape to (1, self.num_tokens, 1)
+ weight = weight.view(1, -1, 1)
+ ip_hidden_states = ip_hidden_states * weight
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ scale = self.scale
+ hidden_states = hidden_states + scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+ # this ensures that the ip_scaler is not changed when we load the model
+ # def _apply(self, fn):
+ # if hasattr(self, "ip_scaler"):
+ # # Overriding the _apply method to prevent the special_parameter from changing dtype
+ # self.ip_scaler = fn(self.ip_scaler)
+ # # Temporarily set the special_parameter to None to exclude it from default _apply processing
+ # ip_scaler = self.ip_scaler
+ # self.ip_scaler = None
+ # super(CustomIPAttentionProcessor, self)._apply(fn)
+ # # Restore the special_parameter after the default _apply processing
+ # self.ip_scaler = ip_scaler
+ # return self
+ # else:
+ # return super(CustomIPAttentionProcessor, self)._apply(fn)
+
+
+class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False,
+ full_token_scaler=False):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.train_scaler = train_scaler
+ self.num_tokens = num_tokens
+ if train_scaler:
+ if full_token_scaler:
+ self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+ else:
+ self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+ # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+ self.ip_scaler.requires_grad_(True)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ is_active = self.adapter_ref().is_active
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # begin ip adapter
+ if not is_active:
+ ip_hidden_states = None
+ else:
+ # get ip hidden states. Should be stored
+ ip_hidden_states = self.adapter_ref().last_conditional
+ # add unconditional to front if it exists
+ if ip_hidden_states.shape[0] * 2 == batch_size:
+ if self.adapter_ref().last_unconditional is None:
+ raise ValueError("Unconditional is None but should not be")
+ ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)
+
+ if ip_hidden_states is not None:
+ # apply scaler
+ if self.train_scaler:
+ weight = self.ip_scaler
+ # reshape to (1, self.num_tokens, 1)
+ weight = weight.view(1, -1, 1)
+ ip_hidden_states = ip_hidden_states * weight
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ scale = self.scale
+ hidden_states = hidden_states + scale * ip_hidden_states
+ # end ip adapter
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
+class IPAdapter(torch.nn.Module):
+ """IP-Adapter"""
+
+ def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
+ super().__init__()
+ self.config = adapter_config
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ self.device = self.sd_ref().unet.device
+ self.preprocessor: Optional[CLIPImagePreProcessor] = None
+ self.input_size = 224
+ self.clip_noise_zero = True
+ self.unconditional: torch.Tensor = None
+
+ self.last_conditional: torch.Tensor = None
+ self.last_unconditional: torch.Tensor = None
+
+ self.additional_loss = None
+ if self.config.image_encoder_arch.startswith("clip"):
+ try:
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.clip_image_processor = CLIPImageProcessor()
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ adapter_config.image_encoder_path,
+ ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'siglip':
+ from transformers import SiglipImageProcessor, SiglipVisionModel
+ try:
+ self.clip_image_processor = SiglipImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.clip_image_processor = SiglipImageProcessor()
+ self.image_encoder = SiglipVisionModel.from_pretrained(
+ adapter_config.image_encoder_path,
+ ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'vit':
+ try:
+ self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.clip_image_processor = ViTFeatureExtractor()
+ self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'safe':
+ try:
+ self.clip_image_processor = SAFEImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ self.clip_image_processor = SAFEImageProcessor()
+ self.image_encoder = SAFEVisionModel(
+ in_channels=3,
+ num_tokens=self.config.safe_tokens,
+ num_vectors=sd.unet.config['cross_attention_dim'],
+ reducer_channels=self.config.safe_reducer_channels,
+ channels=self.config.safe_channels,
+ downscale_factor=8
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'convnext':
+ try:
+ self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ print(f"could not load image processor from {adapter_config.image_encoder_path}")
+ self.clip_image_processor = ConvNextImageProcessor(
+ size=320,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ )
+ self.image_encoder = ConvNextForImageClassification.from_pretrained(
+ adapter_config.image_encoder_path,
+ use_safetensors=True,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'convnextv2':
+ try:
+ self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ print(f"could not load image processor from {adapter_config.image_encoder_path}")
+ self.clip_image_processor = ConvNextImageProcessor(
+ size=512,
+ image_mean=[0.485, 0.456, 0.406],
+ image_std=[0.229, 0.224, 0.225],
+ )
+ self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
+ adapter_config.image_encoder_path,
+ use_safetensors=True,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ elif self.config.image_encoder_arch == 'vit-hybrid':
+ try:
+ self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
+ except EnvironmentError:
+ print(f"could not load image processor from {adapter_config.image_encoder_path}")
+ self.clip_image_processor = ViTHybridImageProcessor(
+ size=320,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ )
+ self.image_encoder = ViTHybridForImageClassification.from_pretrained(
+ adapter_config.image_encoder_path,
+ use_safetensors=True,
+ ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ else:
+ raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}")
+
+ if not self.config.train_image_encoder:
+ # compiling the image encoder is currently disabled
+ print('Compiling image encoder is disabled')
+ # torch.compile(self.image_encoder, fullgraph=True)
+
+ self.input_size = self.image_encoder.config.image_size
+
+ if self.config.quad_image: # 2x2 grid of 4 images
+ # self.clip_image_processor.config
+ # quad inputs arrive as a 2x2 grid, so the preprocessor needs 2x the encoder input size
+ preprocessor_input_size = self.image_encoder.config.image_size * 2
+
+ # update the preprocessor so images come in at the right size
+ if 'height' in self.clip_image_processor.size:
+ self.clip_image_processor.size['height'] = preprocessor_input_size
+ self.clip_image_processor.size['width'] = preprocessor_input_size
+ elif hasattr(self.clip_image_processor, 'crop_size'):
+ self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+ self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+ self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+
+ if self.config.image_encoder_arch == 'clip+':
+ # self.clip_image_processor.config
+ # the CLIP pre-processor takes inputs at 4x the encoder size and downscales them, so adjust the expected input size
+ preprocessor_input_size = self.image_encoder.config.image_size * 4
+
+ # update the preprocessor so images come in at the right size
+ self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+ self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+ self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+
+ self.preprocessor = CLIPImagePreProcessor(
+ input_size=preprocessor_input_size,
+ clip_input_size=self.image_encoder.config.image_size,
+ )
+ if not self.config.image_encoder_arch == 'safe':
+ if 'height' in self.clip_image_processor.size:
+ self.input_size = self.clip_image_processor.size['height']
+ elif hasattr(self.clip_image_processor, 'crop_size'):
+ self.input_size = self.clip_image_processor.crop_size['height']
+ elif 'shortest_edge' in self.clip_image_processor.size.keys():
+ self.input_size = self.clip_image_processor.size['shortest_edge']
+ else:
+ raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
+ self.current_scale = 1.0
+ self.is_active = True
+ is_pixart = sd.is_pixart
+ is_flux = sd.is_flux
+ if adapter_config.type == 'ip':
+ # ip-adapter
+ image_proj_model = ImageProjModel(
+ cross_attention_dim=sd.unet.config['cross_attention_dim'],
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
+ clip_extra_context_tokens=self.config.num_tokens, # usually 4
+ )
+ elif adapter_config.type == 'ip_clip_face':
+ cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
+ image_proj_model = MLPProjModelClipFace(
+ cross_attention_dim=cross_attn_dim,
+ id_embeddings_dim=self.image_encoder.config.projection_dim,
+ num_tokens=self.config.num_tokens, # usually 4
+ )
+ elif adapter_config.type == 'ip+':
+ heads = 12 if not sd.is_xl else 20
+ if is_flux:
+ dim = 1280
+ else:
+ dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
+ embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
+ 'convnext') else \
+ self.image_encoder.config.hidden_sizes[-1]
+
+ image_encoder_state_dict = self.image_encoder.state_dict()
+ # max_seq_len = CLIP tokens + CLS token
+ max_seq_len = 257
+ if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+ # clip
+ max_seq_len = int(
+ image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+ if is_pixart:
+ heads = 20
+ dim = 1280
+ output_dim = 4096
+ elif is_flux:
+ heads = 20
+ dim = 1280
+ output_dim = 3072
+ else:
+ output_dim = sd.unet.config['cross_attention_dim']
+
+ if self.config.image_encoder_arch.startswith('convnext'):
+ in_tokens = 16 * 16
+ embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+
+ # ip-adapter-plus
+ image_proj_model = Resampler(
+ dim=dim,
+ depth=4,
+ dim_head=64,
+ heads=heads,
+ num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
+ embedding_dim=embedding_dim,
+ max_seq_len=max_seq_len,
+ output_dim=output_dim,
+ ff_mult=4
+ )
+ elif adapter_config.type == 'ipz':
+ dim = sd.unet.config['cross_attention_dim']
+ if hasattr(self.image_encoder.config, 'hidden_sizes'):
+ embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+ else:
+ embedding_dim = self.image_encoder.config.target_hidden_size
+
+ image_encoder_state_dict = self.image_encoder.state_dict()
+ # max_seq_len = CLIP tokens + CLS token
+ in_tokens = 257
+ if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+ # clip
+ in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+ if self.config.image_encoder_arch.startswith('convnext'):
+ in_tokens = 16 * 16
+ embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+
+ is_conv_next = self.config.image_encoder_arch.startswith('convnext')
+
+ out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens
+ # ip-adapter-plus
+ image_proj_model = ZipperResampler(
+ in_size=embedding_dim,
+ in_tokens=in_tokens,
+ out_size=dim,
+ out_tokens=out_tokens,
+ hidden_size=embedding_dim,
+ hidden_tokens=in_tokens,
+ # num_blocks=1 if not is_conv_next else 2,
+ num_blocks=1 if not is_conv_next else 2,
+ is_conv_input=is_conv_next
+ )
+ elif adapter_config.type == 'ilora':
+ # we apply the clip encodings to the LoRA
+ image_proj_model = None
+ else:
+ raise ValueError(f"unknown adapter type: {adapter_config.type}")
+
+ # init adapter modules
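+ # IP attention processors are created for the cross-attention layers (plain processors elsewhere);
+ # their to_k_ip / to_v_ip weights are seeded from the base model's to_k / to_v weights (or small
+ # random values when the Flux weights are quantized), and scaled down for PixArt / Flux which are more sensitive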
+ attn_procs = {}
+ unet_sd = sd.unet.state_dict()
+ attn_processor_keys = []
+ if is_pixart:
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+ # cross attention
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+ elif is_flux:
+ transformer: FluxTransformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn")
+
+ # single transformer blocks do not have cross attn, but we will do them anyway
+ for i, module in transformer.single_transformer_blocks.named_children():
+ attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
+ else:
+ attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+ attn_processor_names = []
+
+ blocks = []
+ transformer_blocks = []
+ for name in attn_processor_keys:
+ name_split = name.split(".")
+ block_name = f"{name_split[0]}.{name_split[1]}"
+ transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
+ if transformer_idx >= 0:
+ transformer_name = ".".join(name_split[:2])
+ transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
+ if transformer_name not in transformer_blocks:
+ transformer_blocks.append(transformer_name)
+
+
+ if block_name not in blocks:
+ blocks.append(block_name)
+ if is_flux:
+ cross_attention_dim = None
+ else:
+ cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
+ sd.unet.config['cross_attention_dim']
+ if name.startswith("mid_block"):
+ hidden_size = sd.unet.config['block_out_channels'][-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = sd.unet.config['block_out_channels'][block_id]
+ elif name.startswith("transformer") or name.startswith("single_transformer"):
+ if is_flux:
+ hidden_size = 3072
+ else:
+ hidden_size = sd.unet.config['cross_attention_dim']
+ else:
+ # upstream did not handle this case, but it would leave hidden_size undefined below
+ raise ValueError(f"unknown attn processor name: {name}")
+ if cross_attention_dim is None and not is_flux:
+ attn_procs[name] = AttnProcessor2_0()
+ else:
+ layer_name = name.split(".processor")[0]
+
+ # if quantized, we need to scale the weights
+ if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:
+ # is quantized
+
+ k_weight = torch.randn(hidden_size, hidden_size) * 0.01
+ v_weight = torch.randn(hidden_size, hidden_size) * 0.01
+ k_weight = k_weight.to(self.sd_ref().torch_dtype)
+ v_weight = v_weight.to(self.sd_ref().torch_dtype)
+ else:
+ k_weight = unet_sd[layer_name + ".to_k.weight"]
+ v_weight = unet_sd[layer_name + ".to_v.weight"]
+
+ weights = {
+ "to_k_ip.weight": k_weight,
+ "to_v_ip.weight": v_weight
+ }
+
+ if is_flux:
+ attn_procs[name] = CustomIPFluxAttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=self.config.num_tokens,
+ adapter=self,
+ train_scaler=self.config.train_scaler or self.config.merge_scaler,
+ full_token_scaler=False
+ )
+ else:
+ attn_procs[name] = CustomIPAttentionProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=self.config.num_tokens,
+ adapter=self,
+ train_scaler=self.config.train_scaler or self.config.merge_scaler,
+ # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
+ full_token_scaler=False
+ )
+ if self.sd_ref().is_pixart or self.sd_ref().is_flux:
+ # pixart is much more sensitive
+ weights = {
+ "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01,
+ "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01,
+ }
+
+ attn_procs[name].load_state_dict(weights, strict=False)
+ attn_processor_names.append(name)
+ print(f"Attn Processors")
+ print(attn_processor_names)
+ if self.sd_ref().is_pixart:
+ # we have to set them ourselves
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+ module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+ self.adapter_modules = torch.nn.ModuleList(
+ [
+ transformer.transformer_blocks[i].attn2.processor for i in
+ range(len(transformer.transformer_blocks))
+ ])
+ elif self.sd_ref().is_flux:
+ # we have to set them ourselves
+ transformer: FluxTransformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
+
+ # do single blocks too even though they don't have cross attn
+ for i, module in transformer.single_transformer_blocks.named_children():
+ module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+
+ self.adapter_modules = torch.nn.ModuleList(
+ [
+ transformer.transformer_blocks[i].attn.processor for i in
+ range(len(transformer.transformer_blocks))
+ ] + [
+ transformer.single_transformer_blocks[i].attn.processor for i in
+ range(len(transformer.single_transformer_blocks))
+ ]
+ )
+ else:
+ sd.unet.set_attn_processor(attn_procs)
+ self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+
+ sd.adapter = self
+ self.unet_ref: weakref.ref = weakref.ref(sd.unet)
+ self.image_proj_model = image_proj_model
+ # load the weights if we have some
+ if self.config.name_or_path:
+ loaded_state_dict = load_ip_adapter_model(
+ self.config.name_or_path,
+ device='cpu',
+ dtype=sd.torch_dtype
+ )
+ self.load_state_dict(loaded_state_dict)
+
+ self.set_scale(1.0)
+
+ if self.config.train_image_encoder:
+ self.image_encoder.train()
+ self.image_encoder.requires_grad_(True)
+
+ # premake a unconditional
+ zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16)
+ self.unconditional = self.clip_image_processor(
+ images=zerod,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ ).pixel_values
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ self.image_encoder.to(*args, **kwargs)
+ self.image_proj_model.to(*args, **kwargs)
+ self.adapter_modules.to(*args, **kwargs)
+ if self.preprocessor is not None:
+ self.preprocessor.to(*args, **kwargs)
+ return self
+
+ # def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
+ # self.image_proj_model.load_state_dict(state_dict["image_proj"])
+ # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+ # ip_layers.load_state_dict(state_dict["ip_adapter"])
+ # if self.config.train_image_encoder and 'image_encoder' in state_dict:
+ # self.image_encoder.load_state_dict(state_dict["image_encoder"])
+ # if self.preprocessor is not None and 'preprocessor' in state_dict:
+ # self.preprocessor.load_state_dict(state_dict["preprocessor"])
+
+ # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
+ # self.load_ip_adapter(state_dict)
+
+ def state_dict(self) -> OrderedDict:
+ state_dict = OrderedDict()
+ if self.config.train_only_image_encoder:
+ return self.image_encoder.state_dict()
+ if self.config.train_scaler:
+ state_dict["ip_scale"] = self.adapter_modules.state_dict()
+ # remove items that are not scalers
+ for key in list(state_dict["ip_scale"].keys()):
+ if not key.endswith("ip_scaler"):
+ del state_dict["ip_scale"][key]
+ return state_dict
+
+ state_dict["image_proj"] = self.image_proj_model.state_dict()
+ state_dict["ip_adapter"] = self.adapter_modules.state_dict()
+ # handle merge scaler training
+ if self.config.merge_scaler:
+ for key in list(state_dict["ip_adapter"].keys()):
+ if key.endswith("ip_scaler"):
+ # merge the scaler into the weights so we don't have to save it and it stays compatible with other IP adapters
+ scale = state_dict["ip_adapter"][key].clone()
+
+ key_start = key.split(".")[-2]
+ # reshape to (1, 1)
+ scale = scale.view(1, 1)
+ del state_dict["ip_adapter"][key]
+ # find the to_k_ip and to_v_ip keys
+ for key2 in list(state_dict["ip_adapter"].keys()):
+ if key2.endswith(f"{key_start}.to_k_ip.weight"):
+ state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale
+ if key2.endswith(f"{key_start}.to_v_ip.weight"):
+ state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale
+
+ if self.config.train_image_encoder:
+ state_dict["image_encoder"] = self.image_encoder.state_dict()
+ if self.preprocessor is not None:
+ state_dict["preprocessor"] = self.preprocessor.state_dict()
+ return state_dict
+
+ def get_scale(self):
+ return self.current_scale
+
+ def set_scale(self, scale):
+ self.current_scale = scale
+ if not self.sd_ref().is_pixart and not self.sd_ref().is_flux:
+ for attn_processor in self.sd_ref().unet.attn_processors.values():
+ if isinstance(attn_processor, CustomIPAttentionProcessor):
+ attn_processor.scale = scale
+
+ # @torch.no_grad()
+ # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
+ # drop=False) -> torch.Tensor:
+ # # todo: add support for sdxl
+ # if isinstance(pil_image, Image.Image):
+ # pil_image = [pil_image]
+ # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+ # clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ # if drop:
+ # clip_image = clip_image * 0
+ # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+ # return clip_image_embeds
+
+ def parse_clip_image_embeds_from_cache(
+ self,
+ image_embeds_list: List[dict], # has ['last_hidden_state', 'image_embeds', 'penultimate_hidden_states']
+ quad_count=4,
+ ):
+ with torch.no_grad():
+ device = self.sd_ref().unet.device
+ clip_image_embeds = torch.cat([x[self.config.clip_layer] for x in image_embeds_list], dim=0)
+
+ if self.config.quad_image:
+ # get the outputs of the quad
+ chunks = clip_image_embeds.chunk(quad_count, dim=0)
+ chunk_sum = torch.zeros_like(chunks[0])
+ for chunk in chunks:
+ chunk_sum = chunk_sum + chunk
+ # get the mean of them
+
+ clip_image_embeds = chunk_sum / quad_count
+
+ clip_image_embeds = clip_image_embeds.to(device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+ return clip_image_embeds
+
+ def get_empty_clip_image(self, batch_size: int) -> torch.Tensor:
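+ # builds a low-magnitude random "null" image, quantized to 8-bit levels and normalized with
+ # the processor's mean/std; used as the unconditional image input for embedding-level CFG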
+ with torch.no_grad():
+ tensors_0_1 = torch.rand([batch_size, 3, self.input_size, self.input_size], device=self.device)
+ noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+ dtype=get_torch_dtype(self.sd_ref().dtype))
+ tensors_0_1 = tensors_0_1 * noise_scale
+ # tensors_0_1 = tensors_0_1 * 0
+ mean = torch.tensor(self.clip_image_processor.image_mean).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ std = torch.tensor(self.clip_image_processor.image_std).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+ clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
+ return clip_image.detach()
+
+ def get_clip_image_embeds_from_tensors(
+ self,
+ tensors_0_1: torch.Tensor,
+ drop=False,
+ is_training=False,
+ has_been_preprocessed=False,
+ quad_count=4,
+ cfg_embed_strength=None, # perform CFG on embeds with unconditional as negative
+ ) -> torch.Tensor:
+ if self.sd_ref().unet.device != self.device:
+ self.to(self.sd_ref().unet.device)
+ if self.sd_ref().unet.device != self.image_encoder.device:
+ self.to(self.sd_ref().unet.device)
+ if not self.config.train:
+ is_training = False
+ uncond_clip = None
+ with torch.no_grad():
+ # on training the clip image is created in the dataloader
+ if not has_been_preprocessed:
+ # tensors should be 0-1
+ if tensors_0_1.ndim == 3:
+ tensors_0_1 = tensors_0_1.unsqueeze(0)
+ # training tensors are 0 - 1
+ tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+
+ # if images are out of this range throw error
+ if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+ raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+ tensors_0_1.min(), tensors_0_1.max()
+ ))
+ # unconditional
+ if drop:
+ if self.clip_noise_zero:
+ tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+ noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+ dtype=get_torch_dtype(self.sd_ref().dtype))
+ tensors_0_1 = tensors_0_1 * noise_scale
+ else:
+ tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+ # tensors_0_1 = tensors_0_1 * 0
+ clip_image = self.clip_image_processor(
+ images=tensors_0_1,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ ).pixel_values
+ else:
+ if drop:
+ # scale the noise down
+ if self.clip_noise_zero:
+ tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+ noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
+ dtype=get_torch_dtype(self.sd_ref().dtype))
+ tensors_0_1 = tensors_0_1 * noise_scale
+ else:
+ tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+ # tensors_0_1 = tensors_0_1 * 0
+ mean = torch.tensor(self.clip_image_processor.image_mean).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ std = torch.tensor(self.clip_image_processor.image_std).to(
+ self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+ ).detach()
+ tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+ clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
+
+ else:
+ clip_image = tensors_0_1
+ clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+
+ if self.config.quad_image:
+ # split the 2x2 quad grid and stack the quadrants on the batch
+ ci1, ci2 = clip_image.chunk(2, dim=2)
+ ci1, ci3 = ci1.chunk(2, dim=3)
+ ci2, ci4 = ci2.chunk(2, dim=3)
+ to_cat = []
+ for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+ if i < quad_count:
+ to_cat.append(ci)
+ else:
+ break
+
+ clip_image = torch.cat(to_cat, dim=0).detach()
+
+ # if drop:
+ # clip_image = clip_image * 0
+ with torch.set_grad_enabled(is_training):
+ if is_training and self.config.train_image_encoder:
+ self.image_encoder.train()
+ clip_image = clip_image.requires_grad_(True)
+ if self.preprocessor is not None:
+ clip_image = self.preprocessor(clip_image)
+ clip_output = self.image_encoder(
+ clip_image,
+ output_hidden_states=True
+ )
+ else:
+ self.image_encoder.eval()
+ if self.preprocessor is not None:
+ clip_image = self.preprocessor(clip_image)
+ clip_output = self.image_encoder(
+ clip_image, output_hidden_states=True
+ )
+
+ if self.config.clip_layer == 'penultimate_hidden_states':
+ # they skip last layer for ip+
+ # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
+ clip_image_embeds = clip_output.hidden_states[-2]
+ elif self.config.clip_layer == 'last_hidden_state':
+ clip_image_embeds = clip_output.hidden_states[-1]
+ else:
+ clip_image_embeds = clip_output.image_embeds
+
+ if self.config.adapter_type == "clip_face":
+ l2_norm = torch.norm(clip_image_embeds, p=2)
+ clip_image_embeds = clip_image_embeds / l2_norm
+
+ if self.config.image_encoder_arch.startswith('convnext'):
+ # flatten the width height layers to make the token space
+ clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
+ # rearrange to (batch, tokens, size)
+ clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
+
+ # apply unconditional if doing cfg on embeds
+ with torch.no_grad():
+ if cfg_embed_strength is not None:
+ uncond_clip = self.get_empty_clip_image(tensors_0_1.shape[0]).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ if self.config.quad_image:
+ # split the 2x2 quad grid and stack the quadrants on the batch
+ ci1, ci2 = uncond_clip.chunk(2, dim=2)
+ ci1, ci3 = ci1.chunk(2, dim=3)
+ ci2, ci4 = ci2.chunk(2, dim=3)
+ to_cat = []
+ for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+ if i < quad_count:
+ to_cat.append(ci)
+ else:
+ break
+
+ uncond_clip = torch.cat(to_cat, dim=0).detach()
+ uncond_clip_output = self.image_encoder(
+ uncond_clip, output_hidden_states=True
+ )
+
+ if self.config.clip_layer == 'penultimate_hidden_states':
+ uncond_clip_output_embeds = uncond_clip_output.hidden_states[-2]
+ elif self.config.clip_layer == 'last_hidden_state':
+ uncond_clip_output_embeds = uncond_clip_output.hidden_states[-1]
+ else:
+ uncond_clip_output_embeds = uncond_clip_output.image_embeds
+ if self.config.adapter_type == "clip_face":
+ l2_norm = torch.norm(uncond_clip_output_embeds, p=2)
+ uncond_clip_output_embeds = uncond_clip_output_embeds / l2_norm
+
+ uncond_clip_output_embeds = uncond_clip_output_embeds.detach()
+
+
+ # apply inverse cfg
+ clip_image_embeds = inverse_classifier_guidance(
+ clip_image_embeds,
+ uncond_clip_output_embeds,
+ cfg_embed_strength
+ )
+
+
+ if self.config.quad_image:
+ # get the outputs of the quad
+ chunks = clip_image_embeds.chunk(quad_count, dim=0)
+ if self.config.train_image_encoder and is_training:
+ # perform a loss across all chunks this will teach the vision encoder to
+ # identify similarities in our pairs of images and ignore things that do not make them similar
+ num_losses = 0
+ total_loss = None
+ for chunk in chunks:
+ for chunk2 in chunks:
+ if chunk is not chunk2:
+ loss = F.mse_loss(chunk, chunk2)
+ if total_loss is None:
+ total_loss = loss
+ else:
+ total_loss = total_loss + loss
+ num_losses += 1
+ if total_loss is not None:
+ total_loss = total_loss / num_losses
+ total_loss = total_loss * 1e-2
+ if self.additional_loss is not None:
+ total_loss = total_loss + self.additional_loss
+ self.additional_loss = total_loss
+
+ chunk_sum = torch.zeros_like(chunks[0])
+ for chunk in chunks:
+ chunk_sum = chunk_sum + chunk
+ # get the mean of them
+
+ clip_image_embeds = chunk_sum / quad_count
+
+ if not is_training or not self.config.train_image_encoder:
+ clip_image_embeds = clip_image_embeds.detach()
+
+ return clip_image_embeds
+
+ # use drop for prompt dropout, or negatives
+ def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor, is_unconditional=False) -> PromptEmbeds:
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+ if self.sd_ref().is_flux:
+ # do not attach to text embeds for flux, we will save and grab them as it messes
+ # with the RoPE to have them in the same tensor
+ if is_unconditional:
+ self.last_unconditional = image_prompt_embeds
+ else:
+ self.last_conditional = image_prompt_embeds
+ else:
+ embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
+ return embeddings
+
+ def train(self: T, mode: bool = True) -> T:
+ if self.config.train_image_encoder:
+ self.image_encoder.train(mode)
+ if not self.config.train_only_image_encoder:
+ for attn_processor in self.adapter_modules:
+ attn_processor.train(mode)
+ if self.image_proj_model is not None:
+ self.image_proj_model.train(mode)
+ return super().train(mode)
+
+ def get_parameter_groups(self, adapter_lr):
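+        # build optimizer param groups so the scaler parameters can use their own learning rate (config.scaler_lr)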
+ param_groups = []
+        # when training only the scaler, nothing else gets a param group
+ if not self.config.train_scaler:
+ param_groups.append({
+ "params": list(self.get_non_scaler_parameters()),
+ "lr": adapter_lr,
+ })
+ if self.config.train_scaler or self.config.merge_scaler:
+ scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
+ param_groups.append({
+ "params": list(self.get_scaler_parameters()),
+ "lr": scaler_lr,
+ })
+ return param_groups
+
+ def get_scaler_parameters(self):
+        # only yield the scaler parameters from the adapter modules
+        for attn_processor in self.adapter_modules:
+            # not every attention processor has an ip_scaler attribute, so check first
+ if hasattr(attn_processor, "ip_scaler"):
+ scaler_param = attn_processor.ip_scaler
+ yield scaler_param
+
+ def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+ if self.config.train_only_image_encoder:
+ if self.config.train_only_image_encoder_positional_embedding:
+ yield from self.image_encoder.vision_model.embeddings.position_embedding.parameters(recurse)
+ else:
+ yield from self.image_encoder.parameters(recurse)
+ return
+ if self.config.train_scaler:
+ # no params
+ return
+
+ for attn_processor in self.adapter_modules:
+ if self.config.train_scaler or self.config.merge_scaler:
+ # todo remove scaler
+ if hasattr(attn_processor, "to_k_ip"):
+ # yield the linear layer
+ yield from attn_processor.to_k_ip.parameters(recurse)
+ if hasattr(attn_processor, "to_v_ip"):
+ # yield the linear layer
+ yield from attn_processor.to_v_ip.parameters(recurse)
+ else:
+ yield from attn_processor.parameters(recurse)
+ yield from self.image_proj_model.parameters(recurse)
+ if self.config.train_image_encoder:
+ yield from self.image_encoder.parameters(recurse)
+ if self.preprocessor is not None:
+ yield from self.preprocessor.parameters(recurse)
+
+ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+ yield from self.get_non_scaler_parameters(recurse)
+ if self.config.train_scaler or self.config.merge_scaler:
+ yield from self.get_scaler_parameters()
+
+ def merge_in_weights(self, state_dict: Mapping[str, Any]):
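+        # best-effort merge for checkpoints whose tensor shapes differ from the current modules:
+        # the overlapping slice of each tensor is copied in and the remaining values keep their current weights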
+ # merge in img_proj weights
+ current_img_proj_state_dict = self.image_proj_model.state_dict()
+ for key, value in state_dict["image_proj"].items():
+ if key in current_img_proj_state_dict:
+ current_shape = current_img_proj_state_dict[key].shape
+ new_shape = value.shape
+ if current_shape != new_shape:
+ try:
+ # merge in what we can and leave the other values as they are
+ if len(current_shape) == 1:
+ current_img_proj_state_dict[key][:new_shape[0]] = value
+ elif len(current_shape) == 2:
+ current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
+ elif len(current_shape) == 3:
+ current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                        elif len(current_shape) == 4:
+                            current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
+ else:
+ raise ValueError(f"unknown shape: {current_shape}")
+ except RuntimeError as e:
+ print(e)
+ print(
+ f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+
+                        if len(current_shape) == 1:
+                            current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
+                        elif len(current_shape) == 2:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                        elif len(current_shape) == 3:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                        elif len(current_shape) == 4:
+                            current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+ else:
+ raise ValueError(f"unknown shape: {current_shape}")
+ print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+ else:
+ current_img_proj_state_dict[key] = value
+ self.image_proj_model.load_state_dict(current_img_proj_state_dict)
+
+ # merge in ip adapter weights
+ current_ip_adapter_state_dict = self.adapter_modules.state_dict()
+ for key, value in state_dict["ip_adapter"].items():
+ if key in current_ip_adapter_state_dict:
+ current_shape = current_ip_adapter_state_dict[key].shape
+ new_shape = value.shape
+ if current_shape != new_shape:
+ try:
+ # merge in what we can and leave the other values as they are
+ if len(current_shape) == 1:
+ current_ip_adapter_state_dict[key][:new_shape[0]] = value
+ elif len(current_shape) == 2:
+ current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
+ elif len(current_shape) == 3:
+ current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
+                        elif len(current_shape) == 4:
+                            current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2], :new_shape[3]] = value
+ else:
+ raise ValueError(f"unknown shape: {current_shape}")
+ print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+ except RuntimeError as e:
+ print(e)
+ print(
+ f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+
+                        if len(current_shape) == 1:
+                            current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
+                        elif len(current_shape) == 2:
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                        elif len(current_shape) == 3:
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                        elif len(current_shape) == 4:
+                            current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {current_shape}")
+                        print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
+
+ else:
+ current_ip_adapter_state_dict[key] = value
+ self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
+
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
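+        # loading is always non-strict so checkpoints with missing or extra keys still load;
+        # shape mismatches are caught below and handed to merge_in_weights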
+ strict = False
+ if self.config.train_scaler and 'ip_scale' in state_dict:
+ self.adapter_modules.load_state_dict(state_dict["ip_scale"], strict=False)
+ if 'ip_adapter' in state_dict:
+ try:
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
+ self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+ except Exception as e:
+ print(e)
+ print("could not load ip adapter weights, trying to merge in weights")
+ self.merge_in_weights(state_dict)
+ if self.config.train_image_encoder and 'image_encoder' in state_dict:
+ self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
+ if self.preprocessor is not None and 'preprocessor' in state_dict:
+ self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
+
+ if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict:
+ # we are loading pure clip weights.
+ self.image_encoder.load_state_dict(state_dict, strict=strict)
+
+ def enable_gradient_checkpointing(self):
+ if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
+ self.image_encoder.enable_gradient_checkpointing()
+ elif hasattr(self.image_encoder, 'gradient_checkpointing'):
+ self.image_encoder.gradient_checkpointing = True
diff --git a/toolkit/job.py b/toolkit/job.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc274fb798efb5780056bd2134eb2940a608a98c
--- /dev/null
+++ b/toolkit/job.py
@@ -0,0 +1,44 @@
+from typing import Union, OrderedDict
+
+from toolkit.config import get_config
+
+
+def get_job(
+ config_path: Union[str, dict, OrderedDict],
+ name=None
+):
+ config = get_config(config_path, name)
+    if not config.get('job'):
+ raise ValueError('config file is invalid. Missing "job" key')
+
+ job = config['job']
+ if job == 'extract':
+ from jobs import ExtractJob
+ return ExtractJob(config)
+ if job == 'train':
+ from jobs import TrainJob
+ return TrainJob(config)
+ if job == 'mod':
+ from jobs import ModJob
+ return ModJob(config)
+ if job == 'generate':
+ from jobs import GenerateJob
+ return GenerateJob(config)
+ if job == 'extension':
+ from jobs import ExtensionJob
+ return ExtensionJob(config)
+
+ # elif job == 'train':
+ # from jobs import TrainJob
+ # return TrainJob(config)
+ else:
+ raise ValueError(f'Unknown job type {job}')
+
+
+def run_job(
+ config: Union[str, dict, OrderedDict],
+ name=None
+):
+ job = get_job(config, name)
+ job.run()
+ job.cleanup()
diff --git a/toolkit/keymaps/stable_diffusion_refiner.json b/toolkit/keymaps/stable_diffusion_refiner.json
new file mode 100644
index 0000000000000000000000000000000000000000..4c7525d8804da9ec92b7f87bc01741d4372ac83d
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_refiner.json
@@ -0,0 +1,3498 @@
+{
+ "ldm_diffusers_keymap": {
+ "conditioner.embedders.0.model.ln_final.bias": "te1_text_model.final_layer_norm.bias",
+ "conditioner.embedders.0.model.ln_final.weight": "te1_text_model.final_layer_norm.weight",
+ "conditioner.embedders.0.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.0.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight",
+ "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight",
+ "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight",
+ "first_stage_model.quant_conv.bias": "vae_quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
+ "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight",
+ "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias",
+ "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight",
+ "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias",
+ "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight",
+ "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "unet_conv_out.bias",
+ "model.diffusion_model.out.2.weight": "unet_conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
+ "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
+ "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
+ "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
+ },
+ "ldm_diffusers_shape_map": {
+ "first_stage_model.decoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ]
+ },
+ "ldm_diffusers_operator_map": {
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight"
+ ]
+ }
+ },
+ "diffusers_ldm_operator_map": {
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.0.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.1.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.10.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.11.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.12.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.13.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.14.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.15.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.16.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.17.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.18.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.19.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.2.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.20.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.21.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.22.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.23.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.24.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.25.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.26.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.27.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.28.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.29.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.3.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.30.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.31.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.4.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.5.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.6.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.7.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.8.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.0.model.transformer.resblocks.9.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..02e3ebb921777760664f8073d2131f2503b60d81
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_refiner_ldm_base.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9d9f0fd82268e59d252443653a6e3c30ef4d30fb0b418fadadcebc2572608aa
+size 3277018
diff --git a/toolkit/keymaps/stable_diffusion_refiner_unmatched.json b/toolkit/keymaps/stable_diffusion_refiner_unmatched.json
new file mode 100644
index 0000000000000000000000000000000000000000..cb5aba0a543c8ad50094abb3f99e266336908aaa
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_refiner_unmatched.json
@@ -0,0 +1,27 @@
+{
+ "ldm": {
+ "conditioner.embedders.0.model.logit_scale": {
+ "shape": [],
+ "min": 4.60546875,
+ "max": 4.60546875
+ },
+ "conditioner.embedders.0.model.text_projection": {
+ "shape": [
+ 1280,
+ 1280
+ ],
+ "min": -0.15966796875,
+ "max": 0.230712890625
+ }
+ },
+ "diffusers": {
+ "te1_text_projection.weight": {
+ "shape": [
+ 1280,
+ 1280
+ ],
+ "min": -0.15966796875,
+ "max": 0.230712890625
+ }
+ }
+}
\ No newline at end of file
diff --git a/toolkit/keymaps/stable_diffusion_sd1.json b/toolkit/keymaps/stable_diffusion_sd1.json
new file mode 100644
index 0000000000000000000000000000000000000000..8f04f753ac6656fdc2a2d44d8d07ebc7db184689
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sd1.json
@@ -0,0 +1,1234 @@
+{
+ "ldm_diffusers_keymap": {
+ "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "te_text_model.embeddings.position_embedding.weight",
+ "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te_text_model.encoder.layers.0.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te_text_model.encoder.layers.0.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te_text_model.encoder.layers.0.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te_text_model.encoder.layers.0.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te_text_model.encoder.layers.0.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te_text_model.encoder.layers.0.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te_text_model.encoder.layers.1.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te_text_model.encoder.layers.1.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te_text_model.encoder.layers.1.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te_text_model.encoder.layers.1.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te_text_model.encoder.layers.1.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te_text_model.encoder.layers.1.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te_text_model.encoder.layers.10.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te_text_model.encoder.layers.10.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te_text_model.encoder.layers.10.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te_text_model.encoder.layers.10.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te_text_model.encoder.layers.10.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te_text_model.encoder.layers.10.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te_text_model.encoder.layers.11.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te_text_model.encoder.layers.11.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te_text_model.encoder.layers.11.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te_text_model.encoder.layers.11.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te_text_model.encoder.layers.11.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te_text_model.encoder.layers.11.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te_text_model.encoder.layers.2.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te_text_model.encoder.layers.2.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te_text_model.encoder.layers.2.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te_text_model.encoder.layers.2.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te_text_model.encoder.layers.2.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te_text_model.encoder.layers.2.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te_text_model.encoder.layers.3.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te_text_model.encoder.layers.3.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te_text_model.encoder.layers.3.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te_text_model.encoder.layers.3.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te_text_model.encoder.layers.3.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te_text_model.encoder.layers.3.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te_text_model.encoder.layers.4.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te_text_model.encoder.layers.4.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te_text_model.encoder.layers.4.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te_text_model.encoder.layers.4.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te_text_model.encoder.layers.4.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te_text_model.encoder.layers.4.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te_text_model.encoder.layers.5.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te_text_model.encoder.layers.5.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te_text_model.encoder.layers.5.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te_text_model.encoder.layers.5.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te_text_model.encoder.layers.5.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te_text_model.encoder.layers.5.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te_text_model.encoder.layers.6.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te_text_model.encoder.layers.6.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te_text_model.encoder.layers.6.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te_text_model.encoder.layers.6.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te_text_model.encoder.layers.6.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te_text_model.encoder.layers.6.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te_text_model.encoder.layers.7.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te_text_model.encoder.layers.7.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te_text_model.encoder.layers.7.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te_text_model.encoder.layers.7.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te_text_model.encoder.layers.7.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te_text_model.encoder.layers.7.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te_text_model.encoder.layers.8.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te_text_model.encoder.layers.8.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te_text_model.encoder.layers.8.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te_text_model.encoder.layers.8.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te_text_model.encoder.layers.8.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te_text_model.encoder.layers.8.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te_text_model.encoder.layers.9.layer_norm1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te_text_model.encoder.layers.9.layer_norm2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te_text_model.encoder.layers.9.layer_norm2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te_text_model.encoder.layers.9.mlp.fc1.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te_text_model.encoder.layers.9.mlp.fc1.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te_text_model.encoder.layers.9.mlp.fc2.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te_text_model.encoder.layers.9.mlp.fc2.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te_text_model.encoder.layers.9.self_attn.v_proj.bias",
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te_text_model.encoder.layers.9.self_attn.v_proj.weight",
+ "cond_stage_model.transformer.text_model.final_layer_norm.bias": "te_text_model.final_layer_norm.bias",
+ "cond_stage_model.transformer.text_model.final_layer_norm.weight": "te_text_model.final_layer_norm.weight",
+ "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight",
+ "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight",
+ "first_stage_model.quant_conv.bias": "vae_quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
+ "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.1.1.norm.bias": "unet_down_blocks.0.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.1.1.norm.weight": "unet_down_blocks.0.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_in.bias": "unet_down_blocks.0.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_in.weight": "unet_down_blocks.0.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_out.bias": "unet_down_blocks.0.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_out.weight": "unet_down_blocks.0.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.2.1.norm.bias": "unet_down_blocks.0.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.2.1.norm.weight": "unet_down_blocks.0.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_in.bias": "unet_down_blocks.0.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_in.weight": "unet_down_blocks.0.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_out.bias": "unet_down_blocks.0.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_out.weight": "unet_down_blocks.0.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight",
+ "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "unet_conv_out.bias",
+ "model.diffusion_model.out.2.weight": "unet_conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.1.norm.bias": "unet_up_blocks.3.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.10.1.norm.weight": "unet_up_blocks.3.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_in.bias": "unet_up_blocks.3.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_in.weight": "unet_up_blocks.3.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_out.bias": "unet_up_blocks.3.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_out.weight": "unet_up_blocks.3.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.11.1.norm.bias": "unet_up_blocks.3.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.11.1.norm.weight": "unet_up_blocks.3.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_in.bias": "unet_up_blocks.3.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_in.weight": "unet_up_blocks.3.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_out.bias": "unet_up_blocks.3.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_out.weight": "unet_up_blocks.3.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.9.1.norm.bias": "unet_up_blocks.3.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.9.1.norm.weight": "unet_up_blocks.3.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_in.bias": "unet_up_blocks.3.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_in.weight": "unet_up_blocks.3.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_out.bias": "unet_up_blocks.3.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_out.weight": "unet_up_blocks.3.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
+ "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
+ "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
+ "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
+ },
+ "ldm_diffusers_shape_map": {
+ "first_stage_model.decoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ]
+ },
+ "ldm_diffusers_operator_map": {},
+ "diffusers_ldm_operator_map": {}
+}
\ No newline at end of file
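
The file added above is a generated key-translation table: `ldm_diffusers_keymap` pairs each LDM state-dict key with its Diffusers-style counterpart, `ldm_diffusers_shape_map` records the few tensors whose shapes differ between the two layouts (the VAE mid-block attention projections, stored as 1x1 convolutions on the LDM side and as 512x512 linear weights on the Diffusers side), and the two operator maps are empty for this model. The toolkit's own conversion code is not part of this hunk; the snippet below is only a minimal, hypothetical sketch of how such a keymap could be applied to rename and reshape a checkpoint, with the input and output file names being placeholders.

import json
from safetensors.torch import load_file, save_file

# Load the keymap added in this diff (path assumed from the surrounding hunks).
with open("toolkit/keymaps/stable_diffusion_sd1.json") as f:
    keymap = json.load(f)

# Placeholder input: an SD 1.x checkpoint in the original LDM layout.
ldm_state = load_file("sd15_ldm_checkpoint.safetensors")

diffusers_state = {}
for ldm_key, tensor in ldm_state.items():
    # Rename the key if the table knows it; otherwise keep it unchanged.
    new_key = keymap["ldm_diffusers_keymap"].get(ldm_key, ldm_key)
    # Shape-map entries are [ldm_shape, diffusers_shape] pairs, e.g.
    # [[512, 512, 1, 1], [512, 512]] for the VAE attention projections.
    if ldm_key in keymap["ldm_diffusers_shape_map"]:
        _, target_shape = keymap["ldm_diffusers_shape_map"][ldm_key]
        tensor = tensor.reshape(target_shape)
    diffusers_state[new_key] = tensor

# Placeholder output: the same weights under Diffusers-style key names.
save_file(diffusers_state, "sd15_diffusers_checkpoint.safetensors")
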
diff --git a/toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..8e2c4cb90b8d10d6c9a844a3b73ef3e07541f130
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sd1_ldm_base.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
+size 16
diff --git a/toolkit/keymaps/stable_diffusion_sd2.json b/toolkit/keymaps/stable_diffusion_sd2.json
new file mode 100644
index 0000000000000000000000000000000000000000..868facaf5b6119f5d3a82d369fe509b82da1f551
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sd2.json
@@ -0,0 +1,2424 @@
+{
+ "ldm_diffusers_keymap": {
+ "cond_stage_model.model.ln_final.bias": "te_text_model.final_layer_norm.bias",
+ "cond_stage_model.model.ln_final.weight": "te_text_model.final_layer_norm.weight",
+ "cond_stage_model.model.positional_embedding": "te_text_model.embeddings.position_embedding.weight",
+ "cond_stage_model.model.token_embedding.weight": "te_text_model.embeddings.token_embedding.weight",
+ "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.bias": "te_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.0.attn.out_proj.weight": "te_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.0.ln_1.bias": "te_text_model.encoder.layers.0.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.0.ln_1.weight": "te_text_model.encoder.layers.0.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.0.ln_2.bias": "te_text_model.encoder.layers.0.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.0.ln_2.weight": "te_text_model.encoder.layers.0.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.bias": "te_text_model.encoder.layers.0.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.0.mlp.c_fc.weight": "te_text_model.encoder.layers.0.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.bias": "te_text_model.encoder.layers.0.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.0.mlp.c_proj.weight": "te_text_model.encoder.layers.0.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.1.attn.out_proj.bias": "te_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.1.attn.out_proj.weight": "te_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.1.ln_1.bias": "te_text_model.encoder.layers.1.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.1.ln_1.weight": "te_text_model.encoder.layers.1.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.1.ln_2.bias": "te_text_model.encoder.layers.1.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.1.ln_2.weight": "te_text_model.encoder.layers.1.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.bias": "te_text_model.encoder.layers.1.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.1.mlp.c_fc.weight": "te_text_model.encoder.layers.1.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.bias": "te_text_model.encoder.layers.1.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.1.mlp.c_proj.weight": "te_text_model.encoder.layers.1.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.bias": "te_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.10.attn.out_proj.weight": "te_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.10.ln_1.bias": "te_text_model.encoder.layers.10.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.10.ln_1.weight": "te_text_model.encoder.layers.10.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.10.ln_2.bias": "te_text_model.encoder.layers.10.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.10.ln_2.weight": "te_text_model.encoder.layers.10.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.bias": "te_text_model.encoder.layers.10.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.10.mlp.c_fc.weight": "te_text_model.encoder.layers.10.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.bias": "te_text_model.encoder.layers.10.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.10.mlp.c_proj.weight": "te_text_model.encoder.layers.10.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.bias": "te_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.11.attn.out_proj.weight": "te_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.11.ln_1.bias": "te_text_model.encoder.layers.11.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.11.ln_1.weight": "te_text_model.encoder.layers.11.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.11.ln_2.bias": "te_text_model.encoder.layers.11.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.11.ln_2.weight": "te_text_model.encoder.layers.11.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.bias": "te_text_model.encoder.layers.11.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.11.mlp.c_fc.weight": "te_text_model.encoder.layers.11.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.bias": "te_text_model.encoder.layers.11.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.11.mlp.c_proj.weight": "te_text_model.encoder.layers.11.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.bias": "te_text_model.encoder.layers.12.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.12.attn.out_proj.weight": "te_text_model.encoder.layers.12.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.12.ln_1.bias": "te_text_model.encoder.layers.12.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.12.ln_1.weight": "te_text_model.encoder.layers.12.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.12.ln_2.bias": "te_text_model.encoder.layers.12.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.12.ln_2.weight": "te_text_model.encoder.layers.12.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.bias": "te_text_model.encoder.layers.12.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.12.mlp.c_fc.weight": "te_text_model.encoder.layers.12.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.bias": "te_text_model.encoder.layers.12.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.12.mlp.c_proj.weight": "te_text_model.encoder.layers.12.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.bias": "te_text_model.encoder.layers.13.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.13.attn.out_proj.weight": "te_text_model.encoder.layers.13.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.13.ln_1.bias": "te_text_model.encoder.layers.13.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.13.ln_1.weight": "te_text_model.encoder.layers.13.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.13.ln_2.bias": "te_text_model.encoder.layers.13.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.13.ln_2.weight": "te_text_model.encoder.layers.13.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.bias": "te_text_model.encoder.layers.13.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.13.mlp.c_fc.weight": "te_text_model.encoder.layers.13.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.bias": "te_text_model.encoder.layers.13.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.13.mlp.c_proj.weight": "te_text_model.encoder.layers.13.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.bias": "te_text_model.encoder.layers.14.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.14.attn.out_proj.weight": "te_text_model.encoder.layers.14.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.14.ln_1.bias": "te_text_model.encoder.layers.14.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.14.ln_1.weight": "te_text_model.encoder.layers.14.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.14.ln_2.bias": "te_text_model.encoder.layers.14.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.14.ln_2.weight": "te_text_model.encoder.layers.14.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.bias": "te_text_model.encoder.layers.14.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.14.mlp.c_fc.weight": "te_text_model.encoder.layers.14.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.bias": "te_text_model.encoder.layers.14.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.14.mlp.c_proj.weight": "te_text_model.encoder.layers.14.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.bias": "te_text_model.encoder.layers.15.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.15.attn.out_proj.weight": "te_text_model.encoder.layers.15.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.15.ln_1.bias": "te_text_model.encoder.layers.15.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.15.ln_1.weight": "te_text_model.encoder.layers.15.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.15.ln_2.bias": "te_text_model.encoder.layers.15.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.15.ln_2.weight": "te_text_model.encoder.layers.15.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.bias": "te_text_model.encoder.layers.15.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.15.mlp.c_fc.weight": "te_text_model.encoder.layers.15.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.bias": "te_text_model.encoder.layers.15.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.15.mlp.c_proj.weight": "te_text_model.encoder.layers.15.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.bias": "te_text_model.encoder.layers.16.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.16.attn.out_proj.weight": "te_text_model.encoder.layers.16.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.16.ln_1.bias": "te_text_model.encoder.layers.16.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.16.ln_1.weight": "te_text_model.encoder.layers.16.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.16.ln_2.bias": "te_text_model.encoder.layers.16.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.16.ln_2.weight": "te_text_model.encoder.layers.16.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.bias": "te_text_model.encoder.layers.16.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.16.mlp.c_fc.weight": "te_text_model.encoder.layers.16.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.bias": "te_text_model.encoder.layers.16.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.16.mlp.c_proj.weight": "te_text_model.encoder.layers.16.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.bias": "te_text_model.encoder.layers.17.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.17.attn.out_proj.weight": "te_text_model.encoder.layers.17.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.17.ln_1.bias": "te_text_model.encoder.layers.17.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.17.ln_1.weight": "te_text_model.encoder.layers.17.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.17.ln_2.bias": "te_text_model.encoder.layers.17.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.17.ln_2.weight": "te_text_model.encoder.layers.17.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.bias": "te_text_model.encoder.layers.17.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.17.mlp.c_fc.weight": "te_text_model.encoder.layers.17.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.bias": "te_text_model.encoder.layers.17.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.17.mlp.c_proj.weight": "te_text_model.encoder.layers.17.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.bias": "te_text_model.encoder.layers.18.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.18.attn.out_proj.weight": "te_text_model.encoder.layers.18.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.18.ln_1.bias": "te_text_model.encoder.layers.18.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.18.ln_1.weight": "te_text_model.encoder.layers.18.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.18.ln_2.bias": "te_text_model.encoder.layers.18.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.18.ln_2.weight": "te_text_model.encoder.layers.18.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.bias": "te_text_model.encoder.layers.18.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.18.mlp.c_fc.weight": "te_text_model.encoder.layers.18.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.bias": "te_text_model.encoder.layers.18.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.18.mlp.c_proj.weight": "te_text_model.encoder.layers.18.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.bias": "te_text_model.encoder.layers.19.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.19.attn.out_proj.weight": "te_text_model.encoder.layers.19.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.19.ln_1.bias": "te_text_model.encoder.layers.19.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.19.ln_1.weight": "te_text_model.encoder.layers.19.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.19.ln_2.bias": "te_text_model.encoder.layers.19.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.19.ln_2.weight": "te_text_model.encoder.layers.19.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.bias": "te_text_model.encoder.layers.19.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.19.mlp.c_fc.weight": "te_text_model.encoder.layers.19.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.bias": "te_text_model.encoder.layers.19.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.19.mlp.c_proj.weight": "te_text_model.encoder.layers.19.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.2.attn.out_proj.bias": "te_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.2.attn.out_proj.weight": "te_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.2.ln_1.bias": "te_text_model.encoder.layers.2.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.2.ln_1.weight": "te_text_model.encoder.layers.2.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.2.ln_2.bias": "te_text_model.encoder.layers.2.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.2.ln_2.weight": "te_text_model.encoder.layers.2.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.bias": "te_text_model.encoder.layers.2.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.2.mlp.c_fc.weight": "te_text_model.encoder.layers.2.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.bias": "te_text_model.encoder.layers.2.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.2.mlp.c_proj.weight": "te_text_model.encoder.layers.2.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.bias": "te_text_model.encoder.layers.20.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.20.attn.out_proj.weight": "te_text_model.encoder.layers.20.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.20.ln_1.bias": "te_text_model.encoder.layers.20.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.20.ln_1.weight": "te_text_model.encoder.layers.20.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.20.ln_2.bias": "te_text_model.encoder.layers.20.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.20.ln_2.weight": "te_text_model.encoder.layers.20.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.bias": "te_text_model.encoder.layers.20.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.20.mlp.c_fc.weight": "te_text_model.encoder.layers.20.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.bias": "te_text_model.encoder.layers.20.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.20.mlp.c_proj.weight": "te_text_model.encoder.layers.20.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.bias": "te_text_model.encoder.layers.21.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.21.attn.out_proj.weight": "te_text_model.encoder.layers.21.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.21.ln_1.bias": "te_text_model.encoder.layers.21.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.21.ln_1.weight": "te_text_model.encoder.layers.21.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.21.ln_2.bias": "te_text_model.encoder.layers.21.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.21.ln_2.weight": "te_text_model.encoder.layers.21.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.bias": "te_text_model.encoder.layers.21.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.21.mlp.c_fc.weight": "te_text_model.encoder.layers.21.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.bias": "te_text_model.encoder.layers.21.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.21.mlp.c_proj.weight": "te_text_model.encoder.layers.21.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.bias": "te_text_model.encoder.layers.22.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight": "te_text_model.encoder.layers.22.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.22.ln_1.bias": "te_text_model.encoder.layers.22.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.22.ln_1.weight": "te_text_model.encoder.layers.22.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.22.ln_2.bias": "te_text_model.encoder.layers.22.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.22.ln_2.weight": "te_text_model.encoder.layers.22.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.bias": "te_text_model.encoder.layers.22.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.22.mlp.c_fc.weight": "te_text_model.encoder.layers.22.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.bias": "te_text_model.encoder.layers.22.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.22.mlp.c_proj.weight": "te_text_model.encoder.layers.22.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.3.attn.out_proj.bias": "te_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.3.attn.out_proj.weight": "te_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.3.ln_1.bias": "te_text_model.encoder.layers.3.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.3.ln_1.weight": "te_text_model.encoder.layers.3.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.3.ln_2.bias": "te_text_model.encoder.layers.3.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.3.ln_2.weight": "te_text_model.encoder.layers.3.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.bias": "te_text_model.encoder.layers.3.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.3.mlp.c_fc.weight": "te_text_model.encoder.layers.3.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.bias": "te_text_model.encoder.layers.3.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.3.mlp.c_proj.weight": "te_text_model.encoder.layers.3.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.4.attn.out_proj.bias": "te_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.4.attn.out_proj.weight": "te_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.4.ln_1.bias": "te_text_model.encoder.layers.4.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.4.ln_1.weight": "te_text_model.encoder.layers.4.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.4.ln_2.bias": "te_text_model.encoder.layers.4.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.4.ln_2.weight": "te_text_model.encoder.layers.4.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.bias": "te_text_model.encoder.layers.4.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.4.mlp.c_fc.weight": "te_text_model.encoder.layers.4.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.bias": "te_text_model.encoder.layers.4.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.4.mlp.c_proj.weight": "te_text_model.encoder.layers.4.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.5.attn.out_proj.bias": "te_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.5.attn.out_proj.weight": "te_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.5.ln_1.bias": "te_text_model.encoder.layers.5.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.5.ln_1.weight": "te_text_model.encoder.layers.5.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.5.ln_2.bias": "te_text_model.encoder.layers.5.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.5.ln_2.weight": "te_text_model.encoder.layers.5.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.bias": "te_text_model.encoder.layers.5.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.5.mlp.c_fc.weight": "te_text_model.encoder.layers.5.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.bias": "te_text_model.encoder.layers.5.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.5.mlp.c_proj.weight": "te_text_model.encoder.layers.5.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.6.attn.out_proj.bias": "te_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.6.attn.out_proj.weight": "te_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.6.ln_1.bias": "te_text_model.encoder.layers.6.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.6.ln_1.weight": "te_text_model.encoder.layers.6.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.6.ln_2.bias": "te_text_model.encoder.layers.6.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.6.ln_2.weight": "te_text_model.encoder.layers.6.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.bias": "te_text_model.encoder.layers.6.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.6.mlp.c_fc.weight": "te_text_model.encoder.layers.6.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.bias": "te_text_model.encoder.layers.6.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.6.mlp.c_proj.weight": "te_text_model.encoder.layers.6.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.7.attn.out_proj.bias": "te_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.7.attn.out_proj.weight": "te_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.7.ln_1.bias": "te_text_model.encoder.layers.7.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.7.ln_1.weight": "te_text_model.encoder.layers.7.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.7.ln_2.bias": "te_text_model.encoder.layers.7.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.7.ln_2.weight": "te_text_model.encoder.layers.7.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.bias": "te_text_model.encoder.layers.7.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.7.mlp.c_fc.weight": "te_text_model.encoder.layers.7.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.bias": "te_text_model.encoder.layers.7.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.7.mlp.c_proj.weight": "te_text_model.encoder.layers.7.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.8.attn.out_proj.bias": "te_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.8.attn.out_proj.weight": "te_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.8.ln_1.bias": "te_text_model.encoder.layers.8.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.8.ln_1.weight": "te_text_model.encoder.layers.8.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.8.ln_2.bias": "te_text_model.encoder.layers.8.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.8.ln_2.weight": "te_text_model.encoder.layers.8.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.bias": "te_text_model.encoder.layers.8.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.8.mlp.c_fc.weight": "te_text_model.encoder.layers.8.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.bias": "te_text_model.encoder.layers.8.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.8.mlp.c_proj.weight": "te_text_model.encoder.layers.8.mlp.fc2.weight",
+ "cond_stage_model.model.transformer.resblocks.9.attn.out_proj.bias": "te_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "cond_stage_model.model.transformer.resblocks.9.attn.out_proj.weight": "te_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "cond_stage_model.model.transformer.resblocks.9.ln_1.bias": "te_text_model.encoder.layers.9.layer_norm1.bias",
+ "cond_stage_model.model.transformer.resblocks.9.ln_1.weight": "te_text_model.encoder.layers.9.layer_norm1.weight",
+ "cond_stage_model.model.transformer.resblocks.9.ln_2.bias": "te_text_model.encoder.layers.9.layer_norm2.bias",
+ "cond_stage_model.model.transformer.resblocks.9.ln_2.weight": "te_text_model.encoder.layers.9.layer_norm2.weight",
+ "cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.bias": "te_text_model.encoder.layers.9.mlp.fc1.bias",
+ "cond_stage_model.model.transformer.resblocks.9.mlp.c_fc.weight": "te_text_model.encoder.layers.9.mlp.fc1.weight",
+ "cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.bias": "te_text_model.encoder.layers.9.mlp.fc2.bias",
+ "cond_stage_model.model.transformer.resblocks.9.mlp.c_proj.weight": "te_text_model.encoder.layers.9.mlp.fc2.weight",
+ "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight",
+ "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight",
+ "first_stage_model.quant_conv.bias": "vae_quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
+ "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.1.1.norm.bias": "unet_down_blocks.0.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.1.1.norm.weight": "unet_down_blocks.0.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_in.bias": "unet_down_blocks.0.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_in.weight": "unet_down_blocks.0.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.1.1.proj_out.bias": "unet_down_blocks.0.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.1.1.proj_out.weight": "unet_down_blocks.0.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.10.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.bias": "unet_down_blocks.3.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.0.weight": "unet_down_blocks.3.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.bias": "unet_down_blocks.3.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.10.0.in_layers.2.weight": "unet_down_blocks.3.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.bias": "unet_down_blocks.3.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.0.weight": "unet_down_blocks.3.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.bias": "unet_down_blocks.3.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.10.0.out_layers.3.weight": "unet_down_blocks.3.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.bias": "unet_down_blocks.3.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.11.0.emb_layers.1.weight": "unet_down_blocks.3.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.bias": "unet_down_blocks.3.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.0.weight": "unet_down_blocks.3.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.bias": "unet_down_blocks.3.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.11.0.in_layers.2.weight": "unet_down_blocks.3.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.bias": "unet_down_blocks.3.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.0.weight": "unet_down_blocks.3.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.bias": "unet_down_blocks.3.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "unet_down_blocks.3.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.2.1.norm.bias": "unet_down_blocks.0.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.2.1.norm.weight": "unet_down_blocks.0.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_in.bias": "unet_down_blocks.0.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_in.weight": "unet_down_blocks.0.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.2.1.proj_out.bias": "unet_down_blocks.0.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.2.1.proj_out.weight": "unet_down_blocks.0.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.9.0.op.bias": "unet_down_blocks.2.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.9.0.op.weight": "unet_down_blocks.2.downsamplers.0.conv.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight",
+ "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "unet_conv_out.bias",
+ "model.diffusion_model.out.2.weight": "unet_conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.10.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.bias": "unet_up_blocks.3.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.0.weight": "unet_up_blocks.3.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.bias": "unet_up_blocks.3.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.10.0.in_layers.2.weight": "unet_up_blocks.3.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.bias": "unet_up_blocks.3.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.0.weight": "unet_up_blocks.3.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.bias": "unet_up_blocks.3.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.10.0.out_layers.3.weight": "unet_up_blocks.3.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.bias": "unet_up_blocks.3.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.10.0.skip_connection.weight": "unet_up_blocks.3.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.10.1.norm.bias": "unet_up_blocks.3.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.10.1.norm.weight": "unet_up_blocks.3.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_in.bias": "unet_up_blocks.3.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_in.weight": "unet_up_blocks.3.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.10.1.proj_out.bias": "unet_up_blocks.3.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.10.1.proj_out.weight": "unet_up_blocks.3.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.11.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.bias": "unet_up_blocks.3.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.0.weight": "unet_up_blocks.3.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.bias": "unet_up_blocks.3.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.11.0.in_layers.2.weight": "unet_up_blocks.3.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.bias": "unet_up_blocks.3.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.0.weight": "unet_up_blocks.3.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.bias": "unet_up_blocks.3.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.11.0.out_layers.3.weight": "unet_up_blocks.3.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.bias": "unet_up_blocks.3.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.11.0.skip_connection.weight": "unet_up_blocks.3.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.11.1.norm.bias": "unet_up_blocks.3.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.11.1.norm.weight": "unet_up_blocks.3.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_in.bias": "unet_up_blocks.3.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_in.weight": "unet_up_blocks.3.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.11.1.proj_out.bias": "unet_up_blocks.3.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.11.1.proj_out.weight": "unet_up_blocks.3.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.2.1.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.6.1.norm.bias": "unet_up_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.6.1.norm.weight": "unet_up_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_in.bias": "unet_up_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_in.weight": "unet_up_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.6.1.proj_out.bias": "unet_up_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.6.1.proj_out.weight": "unet_up_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.1.norm.bias": "unet_up_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.7.1.norm.weight": "unet_up_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_in.bias": "unet_up_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_in.weight": "unet_up_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.7.1.proj_out.bias": "unet_up_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.7.1.proj_out.weight": "unet_up_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.1.norm.bias": "unet_up_blocks.2.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.8.1.norm.weight": "unet_up_blocks.2.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_in.bias": "unet_up_blocks.2.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_in.weight": "unet_up_blocks.2.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.8.1.proj_out.bias": "unet_up_blocks.2.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.8.1.proj_out.weight": "unet_up_blocks.2.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.8.2.conv.bias": "unet_up_blocks.2.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.8.2.conv.weight": "unet_up_blocks.2.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.bias": "unet_up_blocks.3.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.9.0.emb_layers.1.weight": "unet_up_blocks.3.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.bias": "unet_up_blocks.3.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.0.weight": "unet_up_blocks.3.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.bias": "unet_up_blocks.3.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.9.0.in_layers.2.weight": "unet_up_blocks.3.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.bias": "unet_up_blocks.3.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.0.weight": "unet_up_blocks.3.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.bias": "unet_up_blocks.3.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.9.0.out_layers.3.weight": "unet_up_blocks.3.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.bias": "unet_up_blocks.3.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.9.0.skip_connection.weight": "unet_up_blocks.3.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.9.1.norm.bias": "unet_up_blocks.3.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.9.1.norm.weight": "unet_up_blocks.3.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_in.bias": "unet_up_blocks.3.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_in.weight": "unet_up_blocks.3.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.9.1.proj_out.bias": "unet_up_blocks.3.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.9.1.proj_out.weight": "unet_up_blocks.3.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
+ "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
+ "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
+ "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
+ },
+ "ldm_diffusers_shape_map": {
+ "first_stage_model.decoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ]
+ },
+ "ldm_diffusers_operator_map": {
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.0.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.0.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.1.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.1.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.10.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.10.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.11.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.11.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.12.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.12.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.12.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.12.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.12.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.12.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.13.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.13.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.13.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.13.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.13.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.13.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.14.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.14.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.14.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.14.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.14.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.14.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.15.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.15.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.15.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.15.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.15.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.15.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.16.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.16.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.16.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.16.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.16.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.16.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.17.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.17.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.17.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.17.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.17.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.17.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.18.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.18.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.18.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.18.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.18.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.18.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.19.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.19.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.19.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.19.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.19.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.19.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.2.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.2.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.20.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.20.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.20.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.20.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.20.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.20.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.21.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.21.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.21.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.21.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.21.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.21.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.22.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.22.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.22.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.22.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.22.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.22.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.3.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.3.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.4.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.4.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.5.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.5.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.6.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.6.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.7.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.7.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.8.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.8.self_attn.v_proj.weight"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias": {
+ "cat": [
+ "te_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "te_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "te_text_model.encoder.layers.9.self_attn.v_proj.bias"
+ ]
+ },
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight": {
+ "cat": [
+ "te_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "te_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "te_text_model.encoder.layers.9.self_attn.v_proj.weight"
+ ]
+ }
+ },
+ "diffusers_ldm_operator_map": {
+ "te_text_model.encoder.layers.0.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.0.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.0.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.0.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.0.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.0.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.1.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.1.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.1.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.1.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.1.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.1.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.1.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.10.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.10.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.10.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.10.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.10.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.10.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.10.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.11.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.11.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.11.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.11.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.11.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.11.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.11.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.12.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.12.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.12.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.12.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.12.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.12.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.12.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.13.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.13.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.13.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.13.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.13.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.13.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.13.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.14.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.14.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.14.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.14.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.14.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.14.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.14.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.15.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.15.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.15.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.15.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.15.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.15.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.15.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.16.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.16.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.16.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.16.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.16.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.16.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.16.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.17.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.17.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.17.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.17.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.17.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.17.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.17.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.18.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.18.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.18.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.18.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.18.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.18.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.18.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.19.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.19.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.19.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.19.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.19.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.19.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.19.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.2.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.2.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.2.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.2.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.2.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.2.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.2.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.20.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.20.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.20.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.20.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.20.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.20.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.20.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.21.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.21.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.21.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.21.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.21.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.21.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.21.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.22.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.22.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.22.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.22.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.22.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.22.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.22.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.3.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.3.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.3.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.3.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.3.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.3.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.3.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.4.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.4.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.4.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.4.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.4.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.4.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.4.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.5.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.5.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.5.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.5.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.5.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.5.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.5.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.6.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.6.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.6.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.6.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.6.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.6.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.6.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.7.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.7.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.7.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.7.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.7.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.7.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.7.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.8.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.8.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.8.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.8.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.8.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.8.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.8.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.9.self_attn.q_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.9.self_attn.k_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.9.self_attn.v_proj.bias": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te_text_model.encoder.layers.9.self_attn.q_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te_text_model.encoder.layers.9.self_attn.k_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te_text_model.encoder.layers.9.self_attn.v_proj.weight": {
+ "slice": [
+ "cond_stage_model.model.transformer.resblocks.9.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ }
+ }
+}
\ No newline at end of file
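
A short note on how this keymap is presumably consumed: "ldm_diffusers_keymap" lists one-to-one tensor renames, "ldm_diffusers_shape_map" records tensors whose shapes differ between the two layouts (the VAE mid-block attention weights stored as 1x1 convolutions versus plain linear weights), and the two operator maps describe how the fused CLIP attention in_proj tensors are concatenated from, or sliced back into, separate q/k/v projections. The sketch below is not part of this patch and only illustrates one way such a file could be applied to an LDM state dict; the file path and the helper name are assumptions, not the toolkit's own conversion code.

    import json

    def ldm_to_diffusers(ldm_sd, keymap_path="toolkit/keymaps/stable_diffusion_sd2.json"):
        # Hypothetical helper. ldm_sd maps LDM key names to tensors
        # (torch or numpy, anything with .shape, .reshape and slicing).
        with open(keymap_path) as f:
            km = json.load(f)

        out = {}

        # 1) Straight renames from the keymap.
        for ldm_key, diff_key in km["ldm_diffusers_keymap"].items():
            if ldm_key in ldm_sd:
                out[diff_key] = ldm_sd[ldm_key]

        # 2) Reshape where the two layouts disagree, e.g. [512, 512, 1, 1] -> [512, 512].
        for ldm_key, (ldm_shape, diff_shape) in km["ldm_diffusers_shape_map"].items():
            diff_key = km["ldm_diffusers_keymap"].get(ldm_key)
            if diff_key in out and list(out[diff_key].shape) == ldm_shape:
                out[diff_key] = out[diff_key].reshape(diff_shape)

        # 3) Split the fused in_proj tensors using the "slice" specs recorded
        #    for the diffusers -> ldm direction (row ranges such as "0:1024, :").
        for diff_key, op in km["diffusers_ldm_operator_map"].items():
            src_key, spec = op["slice"]
            start, end = spec.split(",")[0].split(":")
            start = int(start) if start else None
            end = int(end) if end else None
            if src_key in ldm_sd:
                out[diff_key] = ldm_sd[src_key][start:end]

        return out
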
diff --git a/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..14d1315934ea605ae2ffdfa143dac5ba4d31788f
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sd2_ldm_base.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25cdb3685616f5851c554f47beb9ff9c09d0aa3d73e4263b2a94384903dea592
+size 27316630
diff --git a/toolkit/keymaps/stable_diffusion_sd2_unmatched.json b/toolkit/keymaps/stable_diffusion_sd2_unmatched.json
new file mode 100644
index 0000000000000000000000000000000000000000..3814d87e7d37f7a1bc565132baf07a269a30422c
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sd2_unmatched.json
@@ -0,0 +1,200 @@
+{
+ "ldm": {
+ "alphas_cumprod": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.00466156005859375,
+ "max": 0.9990234375
+ },
+ "alphas_cumprod_prev": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0047149658203125,
+ "max": 1.0
+ },
+ "betas": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0008502006530761719,
+ "max": 0.01200103759765625
+ },
+ "cond_stage_model.model.logit_scale": {
+ "shape": [],
+ "min": 4.60546875,
+ "max": 4.60546875
+ },
+ "cond_stage_model.model.text_projection": {
+ "shape": [
+ 1024,
+ 1024
+ ],
+ "min": -0.109130859375,
+ "max": 0.09271240234375
+ },
+ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias": {
+ "shape": [
+ 3072
+ ],
+ "min": -2.525390625,
+ "max": 2.591796875
+ },
+ "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight": {
+ "shape": [
+ 3072,
+ 1024
+ ],
+ "min": -0.12261962890625,
+ "max": 0.1258544921875
+ },
+ "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias": {
+ "shape": [
+ 1024
+ ],
+ "min": -0.422607421875,
+ "max": 1.17578125
+ },
+ "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight": {
+ "shape": [
+ 1024,
+ 1024
+ ],
+ "min": -0.0738525390625,
+ "max": 0.08673095703125
+ },
+ "cond_stage_model.model.transformer.resblocks.23.ln_1.bias": {
+ "shape": [
+ 1024
+ ],
+ "min": -3.392578125,
+ "max": 0.90625
+ },
+ "cond_stage_model.model.transformer.resblocks.23.ln_1.weight": {
+ "shape": [
+ 1024
+ ],
+ "min": 0.379638671875,
+ "max": 2.02734375
+ },
+ "cond_stage_model.model.transformer.resblocks.23.ln_2.bias": {
+ "shape": [
+ 1024
+ ],
+ "min": -0.833984375,
+ "max": 2.525390625
+ },
+ "cond_stage_model.model.transformer.resblocks.23.ln_2.weight": {
+ "shape": [
+ 1024
+ ],
+ "min": 1.17578125,
+ "max": 2.037109375
+ },
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias": {
+ "shape": [
+ 4096
+ ],
+ "min": -1.619140625,
+ "max": 0.5595703125
+ },
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight": {
+ "shape": [
+ 4096,
+ 1024
+ ],
+ "min": -0.08953857421875,
+ "max": 0.13232421875
+ },
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias": {
+ "shape": [
+ 1024
+ ],
+ "min": -1.8662109375,
+ "max": 0.74658203125
+ },
+ "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight": {
+ "shape": [
+ 1024,
+ 4096
+ ],
+ "min": -0.12939453125,
+ "max": 0.1009521484375
+ },
+ "log_one_minus_alphas_cumprod": {
+ "shape": [
+ 1000
+ ],
+ "min": -7.0703125,
+ "max": -0.004669189453125
+ },
+ "model_ema.decay": {
+ "shape": [],
+ "min": 1.0,
+ "max": 1.0
+ },
+ "model_ema.num_updates": {
+ "shape": [],
+ "min": 219996,
+ "max": 219996
+ },
+ "posterior_log_variance_clipped": {
+ "shape": [
+ 1000
+ ],
+ "min": -46.0625,
+ "max": -4.421875
+ },
+ "posterior_mean_coef1": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.000827789306640625,
+ "max": 1.0
+ },
+ "posterior_mean_coef2": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0,
+ "max": 0.99560546875
+ },
+ "posterior_variance": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0,
+ "max": 0.01200103759765625
+ },
+ "sqrt_alphas_cumprod": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0682373046875,
+ "max": 0.99951171875
+ },
+ "sqrt_one_minus_alphas_cumprod": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0291595458984375,
+ "max": 0.99755859375
+ },
+ "sqrt_recip_alphas_cumprod": {
+ "shape": [
+ 1000
+ ],
+ "min": 1.0,
+ "max": 14.6484375
+ },
+ "sqrt_recipm1_alphas_cumprod": {
+ "shape": [
+ 1000
+ ],
+ "min": 0.0291595458984375,
+ "max": 14.6171875
+ }
+ },
+ "diffusers": {}
+}
\ No newline at end of file
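
`stable_diffusion_sd2_unmatched.json` lists the LDM keys that have no diffusers counterpart (scheduler buffers such as `alphas_cumprod`, the EMA counters, and the final text-encoder resblock); its `diffusers` section is empty, and each entry records only the tensor shape and observed min/max. The code that consumes this file is not shown in this hunk, so the check below is only a plausible sketch of how that metadata could be used; the function name and file path are made up:

```python
# A minimal sketch, assuming this metadata is used to sanity-check an LDM state dict:
# verify each recorded key exists, has the recorded shape, and stays inside [min, max].
import json
import torch


def check_unmatched(state_dict: dict, metadata_path: str) -> list:
    with open(metadata_path) as f:
        meta = json.load(f)["ldm"]
    problems = []
    for key, info in meta.items():
        t = state_dict.get(key)
        if t is None:
            problems.append(f"missing: {key}")
            continue
        if list(t.shape) != info["shape"]:
            problems.append(f"shape mismatch for {key}: {list(t.shape)} != {info['shape']}")
        elif t.numel() and (t.min().item() < info["min"] or t.max().item() > info["max"]):
            problems.append(f"out-of-range values for {key}")
    return problems
```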
diff --git a/toolkit/keymaps/stable_diffusion_sdxl.json b/toolkit/keymaps/stable_diffusion_sdxl.json
new file mode 100644
index 0000000000000000000000000000000000000000..dd3c24475b9a933839567e20990d0910944ba82e
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sdxl.json
@@ -0,0 +1,4154 @@
+{
+ "ldm_diffusers_keymap": {
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias",
+ "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight",
+ "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias",
+ "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight",
+ "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight",
+ "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight",
+ "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight",
+ "first_stage_model.quant_conv.bias": "vae_quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
+ "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.4.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.5.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.6.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.7.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.8.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.9.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.4.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.5.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.6.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.7.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.8.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.9.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias",
+ "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight",
+ "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias",
+ "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight",
+ "model.diffusion_model.middle_block.1.norm.bias": "unet_mid_block.attentions.0.norm.bias",
+ "model.diffusion_model.middle_block.1.norm.weight": "unet_mid_block.attentions.0.norm.weight",
+ "model.diffusion_model.middle_block.1.proj_in.bias": "unet_mid_block.attentions.0.proj_in.bias",
+ "model.diffusion_model.middle_block.1.proj_in.weight": "unet_mid_block.attentions.0.proj_in.weight",
+ "model.diffusion_model.middle_block.1.proj_out.bias": "unet_mid_block.attentions.0.proj_out.bias",
+ "model.diffusion_model.middle_block.1.proj_out.weight": "unet_mid_block.attentions.0.proj_out.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.1.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.2.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.3.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.4.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.5.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.6.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.7.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.8.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn1.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_k.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_q.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.attn2.to_v.weight": "unet_mid_block.attentions.0.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.bias": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.ff.net.2.weight": "unet_mid_block.attentions.0.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm1.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm2.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.bias": "unet_mid_block.attentions.0.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight": "unet_mid_block.attentions.0.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.middle_block.2.emb_layers.1.bias": "unet_mid_block.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.2.emb_layers.1.weight": "unet_mid_block.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.2.in_layers.0.bias": "unet_mid_block.resnets.1.norm1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.0.weight": "unet_mid_block.resnets.1.norm1.weight",
+ "model.diffusion_model.middle_block.2.in_layers.2.bias": "unet_mid_block.resnets.1.conv1.bias",
+ "model.diffusion_model.middle_block.2.in_layers.2.weight": "unet_mid_block.resnets.1.conv1.weight",
+ "model.diffusion_model.middle_block.2.out_layers.0.bias": "unet_mid_block.resnets.1.norm2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.0.weight": "unet_mid_block.resnets.1.norm2.weight",
+ "model.diffusion_model.middle_block.2.out_layers.3.bias": "unet_mid_block.resnets.1.conv2.bias",
+ "model.diffusion_model.middle_block.2.out_layers.3.weight": "unet_mid_block.resnets.1.conv2.weight",
+ "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "unet_conv_out.bias",
+ "model.diffusion_model.out.2.weight": "unet_conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
+ "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
+ "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
+ "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
+ },
+ "ldm_diffusers_shape_map": {
+ "first_stage_model.decoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ]
+ },
+ "ldm_diffusers_operator_map": {
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight"
+ ]
+ }
+ },
+ "diffusers_ldm_operator_map": {
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ }
+ }
+}
\ No newline at end of file
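
The keymap file above encodes three kinds of conversions between the LDM and diffusers checkpoint layouts: plain key renames in "ldm_diffusers_keymap", shape fix-ups in "ldm_diffusers_shape_map" (e.g. the 1x1 VAE attention convs stored as [512, 512, 1, 1] on the LDM side and [512, 512] on the diffusers side), and the operator maps, where a fused OpenCLIP in_proj tensor is the concatenation ("cat") of the separate q/k/v projections and is split back with "slice" specs such as "0:1280, :". Below is a minimal, purely illustrative sketch of a converter that consumes these conventions; the function name is a placeholder, the state dict is assumed to hold torch tensors, the shape-map values are assumed to be ordered [ldm_shape, diffusers_shape], and only standard json calls and tensor indexing are used.

import json


def ldm_to_diffusers(ldm_state_dict, keymap_path):
    """Convert an LDM-style state dict to diffusers-style keys using a keymap JSON.

    ldm_state_dict is expected to map LDM tensor names to torch tensors
    (e.g. as returned by safetensors.torch.load_file).
    """
    with open(keymap_path) as f:
        km = json.load(f)

    out = {}

    # 1. Straight renames: LDM key -> diffusers key, tensor unchanged.
    for ldm_key, diff_key in km.get("ldm_diffusers_keymap", {}).items():
        if ldm_key in ldm_state_dict:
            out[diff_key] = ldm_state_dict[ldm_key]

    # 2. Shape fix-ups: reshape tensors whose layout differs between the formats
    #    (assumed here to be stored as [ldm_shape, diffusers_shape]).
    for ldm_key, shapes in km.get("ldm_diffusers_shape_map", {}).items():
        diff_key = km.get("ldm_diffusers_keymap", {}).get(ldm_key)
        if diff_key in out:
            out[diff_key] = out[diff_key].reshape(shapes[1])

    # 3. Fused tensors: each diffusers q/k/v projection is a row slice of the
    #    LDM in_proj tensor, e.g. "0:1280, :" / "1280:2560, :" / "2560:, :".
    for diff_key, op in km.get("diffusers_ldm_operator_map", {}).items():
        if "slice" not in op:
            continue
        ldm_key, spec = op["slice"]
        if ldm_key not in ldm_state_dict:
            continue
        start_s, end_s = spec.split(",")[0].split(":")
        start = int(start_s) if start_s else None
        end = int(end_s) if end_s else None
        out[diff_key] = ldm_state_dict[ldm_key][start:end]

    return out

Going the other direction, the "cat" entries in "ldm_diffusers_operator_map" would be applied symmetrically, concatenating the listed diffusers q/k/v tensors (e.g. with torch.cat along dim 0) to rebuild each fused LDM in_proj tensor.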
diff --git a/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..16f4d21046ef187cb3dd34d83b7eaa3aea216394
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sdxl_ldm_base.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:243672eb340dae3396626886ae9c270ad0d212b9df970d3021037f829d8c70a5
+size 3277308
diff --git a/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json b/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json
new file mode 100644
index 0000000000000000000000000000000000000000..d0b2554ae6e6fd8bd12d660cdb64437132c8b52d
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_sdxl_unmatched.json
@@ -0,0 +1,35 @@
+{
+ "ldm": {
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
+ "shape": [
+ 1,
+ 77
+ ],
+ "min": 0.0,
+ "max": 76.0
+ },
+ "conditioner.embedders.1.model.logit_scale": {
+ "shape": [],
+ "min": 4.60546875,
+ "max": 4.60546875
+ },
+ "conditioner.embedders.1.model.text_projection": {
+ "shape": [
+ 1280,
+ 1280
+ ],
+ "min": -0.15966796875,
+ "max": 0.230712890625
+ }
+ },
+ "diffusers": {
+ "te1_text_projection.weight": {
+ "shape": [
+ 1280,
+ 1280
+ ],
+ "min": -0.15966796875,
+ "max": 0.230712890625
+ }
+ }
+}
\ No newline at end of file
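
The "unmatched" file above records, for each side, keys that the direct key-to-key map does not cover (the frozen position_ids buffer, the OpenCLIP logit_scale, and the second text encoder's text_projection), together with their shape and the value range observed in the reference checkpoint, presumably so a converter can recognise and sanity-check them instead of reporting them as missing weights. A small illustrative check built on that convention is sketched below; the function name is a placeholder and the state dict is assumed to hold torch tensors.

import json


def report_unmatched(state_dict, unmatched_path, side="ldm"):
    """Report recorded 'unmatched' keys that are absent or have an unexpected shape."""
    with open(unmatched_path) as f:
        unmatched = json.load(f)[side]  # "ldm" or "diffusers" section

    for key, info in unmatched.items():
        tensor = state_dict.get(key)
        if tensor is None:
            print(f"note: {key} not present (recorded shape {info['shape']})")
        elif list(tensor.shape) != info["shape"]:
            print(f"warning: {key} has shape {list(tensor.shape)}, recorded {info['shape']}")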
diff --git a/toolkit/keymaps/stable_diffusion_ssd.json b/toolkit/keymaps/stable_diffusion_ssd.json
new file mode 100644
index 0000000000000000000000000000000000000000..9ad06407be7c6eedb4fcfa06805827bf1d2f6924
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_ssd.json
@@ -0,0 +1,3419 @@
+{
+ "ldm_diffusers_keymap": {
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias",
+ "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight",
+ "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias",
+ "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight",
+ "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.1.model.text_projection.weight": "te1_text_projection.weight",
+ "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight",
+ "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight",
+ "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight",
+ "first_stage_model.quant_conv.bias": "vae_quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
+ "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.2.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.3.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias",
+ "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight",
+ "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias",
+ "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight",
+ "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "unet_conv_out.bias",
+ "model.diffusion_model.out.2.weight": "unet_conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.2.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.2.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.3.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.3.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.4.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.4.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.5.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.5.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.6.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.6.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.7.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.7.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.8.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.8.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.9.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.9.norm3.weight",
+ "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
+ "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
+ "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
+ "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
+ },
+ "ldm_diffusers_shape_map": {
+ "first_stage_model.decoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ]
+ },
+ "ldm_diffusers_operator_map": {
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight"
+ ]
+ }
+ },
+ "diffusers_ldm_operator_map": {
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "2048:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "0:1024, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "1024:2048, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "2048:, :"
+ ]
+ }
+ }
+}
\ No newline at end of file
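
Reading aid (not part of the patch): the keymap file above has three kinds of entries, and a minimal sketch of how a converter might consume them is shown below. This is an illustrative guess based only on the JSON structure visible here; the function name, the file-path argument, and the assumption that "diffusers_ldm_operator_map" carries the LDM-to-diffusers "slice" rules are all assumptions, not this repository's actual converter API.

```python
# Hedged sketch: consuming a keymap JSON like the one above to rename an
# LDM/SGM-style state dict into diffusers-style keys. Names are illustrative.
import json


def ldm_to_diffusers(ldm_state_dict, keymap_path):
    with open(keymap_path) as f:
        keymap = json.load(f)

    converted = {}

    # 1. Straight renames: one LDM tensor maps to exactly one diffusers tensor.
    for ldm_key, diffusers_key in keymap["ldm_diffusers_keymap"].items():
        if ldm_key in ldm_state_dict:
            converted[diffusers_key] = ldm_state_dict[ldm_key]

    # 2. Shape fixes: per "ldm_diffusers_shape_map", VAE attention projections
    #    are stored as 1x1 convs ([512, 512, 1, 1]) on the LDM side but as
    #    linear weights ([512, 512]) on the diffusers side.
    for ldm_key, (_, diffusers_shape) in keymap["ldm_diffusers_shape_map"].items():
        diffusers_key = keymap["ldm_diffusers_keymap"].get(ldm_key)
        if diffusers_key in converted:
            converted[diffusers_key] = converted[diffusers_key].reshape(diffusers_shape)

    # 3. Split operators: OpenCLIP fuses q/k/v into a single in_proj tensor;
    #    each diffusers q/k/v projection is recovered by slicing along dim 0.
    for diffusers_key, op in keymap["diffusers_ldm_operator_map"].items():
        if "slice" not in op:
            continue
        ldm_key, spec = op["slice"]  # e.g. ["...in_proj_weight", "1024:2048, :"]
        start_str, _, end_str = spec.split(",")[0].partition(":")
        start = int(start_str) if start_str else None
        end = int(end_str) if end_str else None
        converted[diffusers_key] = ldm_state_dict[ldm_key][start:end]

    return converted
```

A diffusers-to-LDM pass would presumably run in the opposite direction, concatenating the separate q/k/v projections back into the fused in_proj tensors according to the "cat" entries in "ldm_diffusers_operator_map".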
diff --git a/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3936b58486f98f1199a5cd7358fd3846ae61199d
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_ssd_ldm_base.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0df922c1d1dd2ff13d557ac95a8c867ba3319a87f1a3d74aeb2022a64361f914
+size 572
diff --git a/toolkit/keymaps/stable_diffusion_ssd_unmatched.json b/toolkit/keymaps/stable_diffusion_ssd_unmatched.json
new file mode 100644
index 0000000000000000000000000000000000000000..6871c9eb6af0c9a4018e1b0ead6c9fd7c7ee387b
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_ssd_unmatched.json
@@ -0,0 +1,21 @@
+{
+ "ldm": {
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_ids": {
+ "shape": [
+ 1,
+ 77
+ ],
+ "min": 0.0,
+ "max": 76.0
+ },
+ "conditioner.embedders.1.model.text_model.embeddings.position_ids": {
+ "shape": [
+ 1,
+ 77
+ ],
+ "min": 0.0,
+ "max": 76.0
+ }
+ },
+ "diffusers": {}
+}
\ No newline at end of file
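
Reading aid (not part of the patch): the "*_unmatched.json" file above lists LDM tensors that deliberately have no diffusers counterpart, here the two CLIP position_ids buffers of shape [1, 77] with values 0..76. A hedged sketch of how such a file might be used is below; the function name and the drop-versus-warn policy are assumptions for illustration only.

```python
# Hedged sketch: separating tensors that the keymap intentionally leaves
# unconverted (listed in the unmatched file) from genuinely unexpected leftovers.
import json


def split_unmatched(ldm_state_dict, unmatched_path):
    # Keys named in the "ldm" section are expected leftovers (e.g. position_ids
    # buffers) and can be dropped silently; anything else remaining after
    # conversion likely deserves a warning.
    with open(unmatched_path) as f:
        unmatched = set(json.load(f)["ldm"].keys())
    kept = {k: v for k, v in ldm_state_dict.items() if k not in unmatched}
    dropped = sorted(set(ldm_state_dict) & unmatched)
    return kept, dropped
```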
diff --git a/toolkit/keymaps/stable_diffusion_vega.json b/toolkit/keymaps/stable_diffusion_vega.json
new file mode 100644
index 0000000000000000000000000000000000000000..4117c201963bea780c16acd720055699b92acf43
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_vega.json
@@ -0,0 +1,3039 @@
+{
+ "ldm_diffusers_keymap": {
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "te0_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "te0_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "te0_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "te0_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "te0_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "te0_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "te0_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "te0_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "te0_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "te0_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "te0_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "te0_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "te0_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "te0_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "te0_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "te0_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "te0_text_model.encoder.layers.0.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "te0_text_model.encoder.layers.0.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "te0_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "te0_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "te0_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "te0_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "te0_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "te0_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "te0_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "te0_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "te0_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "te0_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "te0_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "te0_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "te0_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "te0_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "te0_text_model.encoder.layers.1.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "te0_text_model.encoder.layers.1.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "te0_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "te0_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "te0_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "te0_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "te0_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "te0_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "te0_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "te0_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "te0_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "te0_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "te0_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "te0_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "te0_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "te0_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "te0_text_model.encoder.layers.10.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "te0_text_model.encoder.layers.10.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.bias": "te0_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm1.weight": "te0_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.bias": "te0_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.layer_norm2.weight": "te0_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "te0_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "te0_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "te0_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "te0_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "te0_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "te0_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "te0_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "te0_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "te0_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "te0_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "te0_text_model.encoder.layers.11.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "te0_text_model.encoder.layers.11.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "te0_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "te0_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "te0_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "te0_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "te0_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "te0_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "te0_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "te0_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "te0_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "te0_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "te0_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "te0_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "te0_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "te0_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "te0_text_model.encoder.layers.2.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "te0_text_model.encoder.layers.2.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "te0_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "te0_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "te0_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "te0_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "te0_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "te0_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "te0_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "te0_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "te0_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "te0_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "te0_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "te0_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "te0_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "te0_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "te0_text_model.encoder.layers.3.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "te0_text_model.encoder.layers.3.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "te0_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "te0_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "te0_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "te0_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "te0_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "te0_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "te0_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "te0_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "te0_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "te0_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "te0_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "te0_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "te0_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "te0_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "te0_text_model.encoder.layers.4.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "te0_text_model.encoder.layers.4.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "te0_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "te0_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "te0_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "te0_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "te0_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "te0_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "te0_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "te0_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "te0_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "te0_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "te0_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "te0_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "te0_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "te0_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "te0_text_model.encoder.layers.5.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "te0_text_model.encoder.layers.5.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "te0_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "te0_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "te0_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "te0_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "te0_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "te0_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "te0_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "te0_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "te0_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "te0_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "te0_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "te0_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "te0_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "te0_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "te0_text_model.encoder.layers.6.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "te0_text_model.encoder.layers.6.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "te0_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "te0_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "te0_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "te0_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "te0_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "te0_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "te0_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "te0_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "te0_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "te0_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "te0_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "te0_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "te0_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "te0_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "te0_text_model.encoder.layers.7.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "te0_text_model.encoder.layers.7.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "te0_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "te0_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "te0_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "te0_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "te0_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "te0_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "te0_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "te0_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "te0_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "te0_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "te0_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "te0_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "te0_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "te0_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "te0_text_model.encoder.layers.8.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "te0_text_model.encoder.layers.8.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "te0_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "te0_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "te0_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "te0_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "te0_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "te0_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "te0_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "te0_text_model.encoder.layers.9.mlp.fc2.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "te0_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "te0_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "te0_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "te0_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "te0_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "te0_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "te0_text_model.encoder.layers.9.self_attn.v_proj.bias",
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "te0_text_model.encoder.layers.9.self_attn.v_proj.weight",
+ "conditioner.embedders.0.transformer.text_model.final_layer_norm.bias": "te0_text_model.final_layer_norm.bias",
+ "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": "te0_text_model.final_layer_norm.weight",
+ "conditioner.embedders.1.model.ln_final.bias": "te1_text_model.final_layer_norm.bias",
+ "conditioner.embedders.1.model.ln_final.weight": "te1_text_model.final_layer_norm.weight",
+ "conditioner.embedders.1.model.positional_embedding": "te1_text_model.embeddings.position_embedding.weight",
+ "conditioner.embedders.1.model.text_projection.weight": "te1_text_projection.weight",
+ "conditioner.embedders.1.model.token_embedding.weight": "te1_text_model.embeddings.token_embedding.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "te1_text_model.encoder.layers.0.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "te1_text_model.encoder.layers.0.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "te1_text_model.encoder.layers.0.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "te1_text_model.encoder.layers.0.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "te1_text_model.encoder.layers.0.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "te1_text_model.encoder.layers.0.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "te1_text_model.encoder.layers.0.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "te1_text_model.encoder.layers.0.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "te1_text_model.encoder.layers.0.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "te1_text_model.encoder.layers.0.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "te1_text_model.encoder.layers.1.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "te1_text_model.encoder.layers.1.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "te1_text_model.encoder.layers.1.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "te1_text_model.encoder.layers.1.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "te1_text_model.encoder.layers.1.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "te1_text_model.encoder.layers.1.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "te1_text_model.encoder.layers.1.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "te1_text_model.encoder.layers.1.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "te1_text_model.encoder.layers.1.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "te1_text_model.encoder.layers.1.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "te1_text_model.encoder.layers.10.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "te1_text_model.encoder.layers.10.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "te1_text_model.encoder.layers.10.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "te1_text_model.encoder.layers.10.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "te1_text_model.encoder.layers.10.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "te1_text_model.encoder.layers.10.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "te1_text_model.encoder.layers.10.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "te1_text_model.encoder.layers.10.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "te1_text_model.encoder.layers.10.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "te1_text_model.encoder.layers.10.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "te1_text_model.encoder.layers.11.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "te1_text_model.encoder.layers.11.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "te1_text_model.encoder.layers.11.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "te1_text_model.encoder.layers.11.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "te1_text_model.encoder.layers.11.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "te1_text_model.encoder.layers.11.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "te1_text_model.encoder.layers.11.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "te1_text_model.encoder.layers.11.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "te1_text_model.encoder.layers.11.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "te1_text_model.encoder.layers.11.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "te1_text_model.encoder.layers.12.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "te1_text_model.encoder.layers.12.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "te1_text_model.encoder.layers.12.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "te1_text_model.encoder.layers.12.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "te1_text_model.encoder.layers.12.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "te1_text_model.encoder.layers.12.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "te1_text_model.encoder.layers.12.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "te1_text_model.encoder.layers.12.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "te1_text_model.encoder.layers.12.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "te1_text_model.encoder.layers.12.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "te1_text_model.encoder.layers.13.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "te1_text_model.encoder.layers.13.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "te1_text_model.encoder.layers.13.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "te1_text_model.encoder.layers.13.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "te1_text_model.encoder.layers.13.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "te1_text_model.encoder.layers.13.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "te1_text_model.encoder.layers.13.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "te1_text_model.encoder.layers.13.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "te1_text_model.encoder.layers.13.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "te1_text_model.encoder.layers.13.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "te1_text_model.encoder.layers.14.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "te1_text_model.encoder.layers.14.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "te1_text_model.encoder.layers.14.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "te1_text_model.encoder.layers.14.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "te1_text_model.encoder.layers.14.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "te1_text_model.encoder.layers.14.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "te1_text_model.encoder.layers.14.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "te1_text_model.encoder.layers.14.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "te1_text_model.encoder.layers.14.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "te1_text_model.encoder.layers.14.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "te1_text_model.encoder.layers.15.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "te1_text_model.encoder.layers.15.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "te1_text_model.encoder.layers.15.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "te1_text_model.encoder.layers.15.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "te1_text_model.encoder.layers.15.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "te1_text_model.encoder.layers.15.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "te1_text_model.encoder.layers.15.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "te1_text_model.encoder.layers.15.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "te1_text_model.encoder.layers.15.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "te1_text_model.encoder.layers.15.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "te1_text_model.encoder.layers.16.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "te1_text_model.encoder.layers.16.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "te1_text_model.encoder.layers.16.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "te1_text_model.encoder.layers.16.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "te1_text_model.encoder.layers.16.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "te1_text_model.encoder.layers.16.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "te1_text_model.encoder.layers.16.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "te1_text_model.encoder.layers.16.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "te1_text_model.encoder.layers.16.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "te1_text_model.encoder.layers.16.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "te1_text_model.encoder.layers.17.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "te1_text_model.encoder.layers.17.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "te1_text_model.encoder.layers.17.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "te1_text_model.encoder.layers.17.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "te1_text_model.encoder.layers.17.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "te1_text_model.encoder.layers.17.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "te1_text_model.encoder.layers.17.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "te1_text_model.encoder.layers.17.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "te1_text_model.encoder.layers.17.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "te1_text_model.encoder.layers.17.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "te1_text_model.encoder.layers.18.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "te1_text_model.encoder.layers.18.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "te1_text_model.encoder.layers.18.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "te1_text_model.encoder.layers.18.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "te1_text_model.encoder.layers.18.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "te1_text_model.encoder.layers.18.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "te1_text_model.encoder.layers.18.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "te1_text_model.encoder.layers.18.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "te1_text_model.encoder.layers.18.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "te1_text_model.encoder.layers.18.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "te1_text_model.encoder.layers.19.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "te1_text_model.encoder.layers.19.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "te1_text_model.encoder.layers.19.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "te1_text_model.encoder.layers.19.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "te1_text_model.encoder.layers.19.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "te1_text_model.encoder.layers.19.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "te1_text_model.encoder.layers.19.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "te1_text_model.encoder.layers.19.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "te1_text_model.encoder.layers.19.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "te1_text_model.encoder.layers.19.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "te1_text_model.encoder.layers.2.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "te1_text_model.encoder.layers.2.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "te1_text_model.encoder.layers.2.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "te1_text_model.encoder.layers.2.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "te1_text_model.encoder.layers.2.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "te1_text_model.encoder.layers.2.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "te1_text_model.encoder.layers.2.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "te1_text_model.encoder.layers.2.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "te1_text_model.encoder.layers.2.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "te1_text_model.encoder.layers.2.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "te1_text_model.encoder.layers.20.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "te1_text_model.encoder.layers.20.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "te1_text_model.encoder.layers.20.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "te1_text_model.encoder.layers.20.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "te1_text_model.encoder.layers.20.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "te1_text_model.encoder.layers.20.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "te1_text_model.encoder.layers.20.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "te1_text_model.encoder.layers.20.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "te1_text_model.encoder.layers.20.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "te1_text_model.encoder.layers.20.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "te1_text_model.encoder.layers.21.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "te1_text_model.encoder.layers.21.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "te1_text_model.encoder.layers.21.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "te1_text_model.encoder.layers.21.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "te1_text_model.encoder.layers.21.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "te1_text_model.encoder.layers.21.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "te1_text_model.encoder.layers.21.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "te1_text_model.encoder.layers.21.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "te1_text_model.encoder.layers.21.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "te1_text_model.encoder.layers.21.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "te1_text_model.encoder.layers.22.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "te1_text_model.encoder.layers.22.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "te1_text_model.encoder.layers.22.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "te1_text_model.encoder.layers.22.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "te1_text_model.encoder.layers.22.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "te1_text_model.encoder.layers.22.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "te1_text_model.encoder.layers.22.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "te1_text_model.encoder.layers.22.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "te1_text_model.encoder.layers.22.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "te1_text_model.encoder.layers.22.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "te1_text_model.encoder.layers.23.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "te1_text_model.encoder.layers.23.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "te1_text_model.encoder.layers.23.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "te1_text_model.encoder.layers.23.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "te1_text_model.encoder.layers.23.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "te1_text_model.encoder.layers.23.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "te1_text_model.encoder.layers.23.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "te1_text_model.encoder.layers.23.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "te1_text_model.encoder.layers.23.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "te1_text_model.encoder.layers.23.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "te1_text_model.encoder.layers.24.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "te1_text_model.encoder.layers.24.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "te1_text_model.encoder.layers.24.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "te1_text_model.encoder.layers.24.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "te1_text_model.encoder.layers.24.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "te1_text_model.encoder.layers.24.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "te1_text_model.encoder.layers.24.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "te1_text_model.encoder.layers.24.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "te1_text_model.encoder.layers.24.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "te1_text_model.encoder.layers.24.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "te1_text_model.encoder.layers.25.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "te1_text_model.encoder.layers.25.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "te1_text_model.encoder.layers.25.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "te1_text_model.encoder.layers.25.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "te1_text_model.encoder.layers.25.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "te1_text_model.encoder.layers.25.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "te1_text_model.encoder.layers.25.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "te1_text_model.encoder.layers.25.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "te1_text_model.encoder.layers.25.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "te1_text_model.encoder.layers.25.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "te1_text_model.encoder.layers.26.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "te1_text_model.encoder.layers.26.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "te1_text_model.encoder.layers.26.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "te1_text_model.encoder.layers.26.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "te1_text_model.encoder.layers.26.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "te1_text_model.encoder.layers.26.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "te1_text_model.encoder.layers.26.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "te1_text_model.encoder.layers.26.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "te1_text_model.encoder.layers.26.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "te1_text_model.encoder.layers.26.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "te1_text_model.encoder.layers.27.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "te1_text_model.encoder.layers.27.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "te1_text_model.encoder.layers.27.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "te1_text_model.encoder.layers.27.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "te1_text_model.encoder.layers.27.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "te1_text_model.encoder.layers.27.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "te1_text_model.encoder.layers.27.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "te1_text_model.encoder.layers.27.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "te1_text_model.encoder.layers.27.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "te1_text_model.encoder.layers.27.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "te1_text_model.encoder.layers.28.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "te1_text_model.encoder.layers.28.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "te1_text_model.encoder.layers.28.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "te1_text_model.encoder.layers.28.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "te1_text_model.encoder.layers.28.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "te1_text_model.encoder.layers.28.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "te1_text_model.encoder.layers.28.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "te1_text_model.encoder.layers.28.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "te1_text_model.encoder.layers.28.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "te1_text_model.encoder.layers.28.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "te1_text_model.encoder.layers.29.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "te1_text_model.encoder.layers.29.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "te1_text_model.encoder.layers.29.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "te1_text_model.encoder.layers.29.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "te1_text_model.encoder.layers.29.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "te1_text_model.encoder.layers.29.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "te1_text_model.encoder.layers.29.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "te1_text_model.encoder.layers.29.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "te1_text_model.encoder.layers.29.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "te1_text_model.encoder.layers.29.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "te1_text_model.encoder.layers.3.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "te1_text_model.encoder.layers.3.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "te1_text_model.encoder.layers.3.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "te1_text_model.encoder.layers.3.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "te1_text_model.encoder.layers.3.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "te1_text_model.encoder.layers.3.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "te1_text_model.encoder.layers.3.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "te1_text_model.encoder.layers.3.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "te1_text_model.encoder.layers.3.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "te1_text_model.encoder.layers.3.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "te1_text_model.encoder.layers.30.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "te1_text_model.encoder.layers.30.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "te1_text_model.encoder.layers.30.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "te1_text_model.encoder.layers.30.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "te1_text_model.encoder.layers.30.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "te1_text_model.encoder.layers.30.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "te1_text_model.encoder.layers.30.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "te1_text_model.encoder.layers.30.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "te1_text_model.encoder.layers.30.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "te1_text_model.encoder.layers.30.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "te1_text_model.encoder.layers.31.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "te1_text_model.encoder.layers.31.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "te1_text_model.encoder.layers.31.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "te1_text_model.encoder.layers.31.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "te1_text_model.encoder.layers.31.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "te1_text_model.encoder.layers.31.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "te1_text_model.encoder.layers.31.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "te1_text_model.encoder.layers.31.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "te1_text_model.encoder.layers.31.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "te1_text_model.encoder.layers.31.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "te1_text_model.encoder.layers.4.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "te1_text_model.encoder.layers.4.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "te1_text_model.encoder.layers.4.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "te1_text_model.encoder.layers.4.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "te1_text_model.encoder.layers.4.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "te1_text_model.encoder.layers.4.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "te1_text_model.encoder.layers.4.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "te1_text_model.encoder.layers.4.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "te1_text_model.encoder.layers.4.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "te1_text_model.encoder.layers.4.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "te1_text_model.encoder.layers.5.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "te1_text_model.encoder.layers.5.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "te1_text_model.encoder.layers.5.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "te1_text_model.encoder.layers.5.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "te1_text_model.encoder.layers.5.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "te1_text_model.encoder.layers.5.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "te1_text_model.encoder.layers.5.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "te1_text_model.encoder.layers.5.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "te1_text_model.encoder.layers.5.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "te1_text_model.encoder.layers.5.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "te1_text_model.encoder.layers.6.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "te1_text_model.encoder.layers.6.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "te1_text_model.encoder.layers.6.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "te1_text_model.encoder.layers.6.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "te1_text_model.encoder.layers.6.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "te1_text_model.encoder.layers.6.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "te1_text_model.encoder.layers.6.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "te1_text_model.encoder.layers.6.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "te1_text_model.encoder.layers.6.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "te1_text_model.encoder.layers.6.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "te1_text_model.encoder.layers.7.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "te1_text_model.encoder.layers.7.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "te1_text_model.encoder.layers.7.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "te1_text_model.encoder.layers.7.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "te1_text_model.encoder.layers.7.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "te1_text_model.encoder.layers.7.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "te1_text_model.encoder.layers.7.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "te1_text_model.encoder.layers.7.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "te1_text_model.encoder.layers.7.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "te1_text_model.encoder.layers.7.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "te1_text_model.encoder.layers.8.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "te1_text_model.encoder.layers.8.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "te1_text_model.encoder.layers.8.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "te1_text_model.encoder.layers.8.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "te1_text_model.encoder.layers.8.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "te1_text_model.encoder.layers.8.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "te1_text_model.encoder.layers.8.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "te1_text_model.encoder.layers.8.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "te1_text_model.encoder.layers.8.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "te1_text_model.encoder.layers.8.mlp.fc2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "te1_text_model.encoder.layers.9.self_attn.out_proj.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "te1_text_model.encoder.layers.9.self_attn.out_proj.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "te1_text_model.encoder.layers.9.layer_norm1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "te1_text_model.encoder.layers.9.layer_norm1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "te1_text_model.encoder.layers.9.layer_norm2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "te1_text_model.encoder.layers.9.layer_norm2.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "te1_text_model.encoder.layers.9.mlp.fc1.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "te1_text_model.encoder.layers.9.mlp.fc1.weight",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "te1_text_model.encoder.layers.9.mlp.fc2.bias",
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "te1_text_model.encoder.layers.9.mlp.fc2.weight",
+ "first_stage_model.decoder.conv_in.bias": "vae_decoder.conv_in.bias",
+ "first_stage_model.decoder.conv_in.weight": "vae_decoder.conv_in.weight",
+ "first_stage_model.decoder.conv_out.bias": "vae_decoder.conv_out.bias",
+ "first_stage_model.decoder.conv_out.weight": "vae_decoder.conv_out.weight",
+ "first_stage_model.decoder.mid.attn_1.k.bias": "vae_decoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.decoder.mid.attn_1.k.weight": "vae_decoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "vae_decoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "vae_decoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "vae_decoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "vae_decoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.decoder.mid.attn_1.q.bias": "vae_decoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.decoder.mid.attn_1.q.weight": "vae_decoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.decoder.mid.attn_1.v.bias": "vae_decoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.decoder.mid.attn_1.v.weight": "vae_decoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "vae_decoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "vae_decoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "vae_decoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "vae_decoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "vae_decoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "vae_decoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "vae_decoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "vae_decoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "vae_decoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "vae_decoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "vae_decoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "vae_decoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "vae_decoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "vae_decoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "vae_decoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "vae_decoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.decoder.norm_out.bias": "vae_decoder.conv_norm_out.bias",
+ "first_stage_model.decoder.norm_out.weight": "vae_decoder.conv_norm_out.weight",
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "vae_decoder.up_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "vae_decoder.up_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "vae_decoder.up_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "vae_decoder.up_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "vae_decoder.up_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "vae_decoder.up_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "vae_decoder.up_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "vae_decoder.up_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "vae_decoder.up_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "vae_decoder.up_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "vae_decoder.up_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "vae_decoder.up_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "vae_decoder.up_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "vae_decoder.up_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "vae_decoder.up_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "vae_decoder.up_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "vae_decoder.up_blocks.3.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "vae_decoder.up_blocks.3.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "vae_decoder.up_blocks.3.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "vae_decoder.up_blocks.3.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "vae_decoder.up_blocks.3.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "vae_decoder.up_blocks.3.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "vae_decoder.up_blocks.3.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "vae_decoder.up_blocks.3.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "vae_decoder.up_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "vae_decoder.up_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "vae_decoder.up_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "vae_decoder.up_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "vae_decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "vae_decoder.up_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "vae_decoder.up_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "vae_decoder.up_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "vae_decoder.up_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "vae_decoder.up_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "vae_decoder.up_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "vae_decoder.up_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "vae_decoder.up_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "vae_decoder.up_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "vae_decoder.up_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "vae_decoder.up_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "vae_decoder.up_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "vae_decoder.up_blocks.2.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "vae_decoder.up_blocks.2.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "vae_decoder.up_blocks.2.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "vae_decoder.up_blocks.2.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "vae_decoder.up_blocks.2.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "vae_decoder.up_blocks.2.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "vae_decoder.up_blocks.2.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "vae_decoder.up_blocks.2.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "vae_decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "vae_decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "vae_decoder.up_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "vae_decoder.up_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "vae_decoder.up_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "vae_decoder.up_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "vae_decoder.up_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "vae_decoder.up_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "vae_decoder.up_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "vae_decoder.up_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "vae_decoder.up_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "vae_decoder.up_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "vae_decoder.up_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "vae_decoder.up_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "vae_decoder.up_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "vae_decoder.up_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "vae_decoder.up_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "vae_decoder.up_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "vae_decoder.up_blocks.1.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "vae_decoder.up_blocks.1.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "vae_decoder.up_blocks.1.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "vae_decoder.up_blocks.1.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "vae_decoder.up_blocks.1.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "vae_decoder.up_blocks.1.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "vae_decoder.up_blocks.1.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "vae_decoder.up_blocks.1.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "vae_decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "vae_decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "vae_decoder.up_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "vae_decoder.up_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "vae_decoder.up_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "vae_decoder.up_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "vae_decoder.up_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "vae_decoder.up_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "vae_decoder.up_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "vae_decoder.up_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "vae_decoder.up_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "vae_decoder.up_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "vae_decoder.up_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "vae_decoder.up_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "vae_decoder.up_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "vae_decoder.up_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "vae_decoder.up_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "vae_decoder.up_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "vae_decoder.up_blocks.0.resnets.2.conv1.bias",
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "vae_decoder.up_blocks.0.resnets.2.conv1.weight",
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "vae_decoder.up_blocks.0.resnets.2.conv2.bias",
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "vae_decoder.up_blocks.0.resnets.2.conv2.weight",
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "vae_decoder.up_blocks.0.resnets.2.norm1.bias",
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "vae_decoder.up_blocks.0.resnets.2.norm1.weight",
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "vae_decoder.up_blocks.0.resnets.2.norm2.bias",
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "vae_decoder.up_blocks.0.resnets.2.norm2.weight",
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "vae_decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "vae_decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "first_stage_model.encoder.conv_in.bias": "vae_encoder.conv_in.bias",
+ "first_stage_model.encoder.conv_in.weight": "vae_encoder.conv_in.weight",
+ "first_stage_model.encoder.conv_out.bias": "vae_encoder.conv_out.bias",
+ "first_stage_model.encoder.conv_out.weight": "vae_encoder.conv_out.weight",
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "vae_encoder.down_blocks.0.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "vae_encoder.down_blocks.0.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "vae_encoder.down_blocks.0.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "vae_encoder.down_blocks.0.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "vae_encoder.down_blocks.0.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "vae_encoder.down_blocks.0.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "vae_encoder.down_blocks.0.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "vae_encoder.down_blocks.0.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "vae_encoder.down_blocks.0.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "vae_encoder.down_blocks.0.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "vae_encoder.down_blocks.0.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "vae_encoder.down_blocks.0.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "vae_encoder.down_blocks.0.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "vae_encoder.down_blocks.0.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "vae_encoder.down_blocks.0.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "vae_encoder.down_blocks.0.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "vae_encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "vae_encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "vae_encoder.down_blocks.1.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "vae_encoder.down_blocks.1.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "vae_encoder.down_blocks.1.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "vae_encoder.down_blocks.1.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "vae_encoder.down_blocks.1.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "vae_encoder.down_blocks.1.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "vae_encoder.down_blocks.1.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "vae_encoder.down_blocks.1.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "vae_encoder.down_blocks.1.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "vae_encoder.down_blocks.1.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "vae_encoder.down_blocks.1.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "vae_encoder.down_blocks.1.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "vae_encoder.down_blocks.1.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "vae_encoder.down_blocks.1.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "vae_encoder.down_blocks.1.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "vae_encoder.down_blocks.1.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "vae_encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "vae_encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "vae_encoder.down_blocks.2.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "vae_encoder.down_blocks.2.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "vae_encoder.down_blocks.2.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "vae_encoder.down_blocks.2.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "vae_encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "vae_encoder.down_blocks.2.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "vae_encoder.down_blocks.2.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "vae_encoder.down_blocks.2.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "vae_encoder.down_blocks.2.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "vae_encoder.down_blocks.2.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "vae_encoder.down_blocks.2.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "vae_encoder.down_blocks.2.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "vae_encoder.down_blocks.2.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "vae_encoder.down_blocks.2.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "vae_encoder.down_blocks.2.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "vae_encoder.down_blocks.2.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "vae_encoder.down_blocks.2.resnets.1.norm2.weight",
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "vae_encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "vae_encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "vae_encoder.down_blocks.3.resnets.0.conv1.bias",
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "vae_encoder.down_blocks.3.resnets.0.conv1.weight",
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "vae_encoder.down_blocks.3.resnets.0.conv2.bias",
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "vae_encoder.down_blocks.3.resnets.0.conv2.weight",
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "vae_encoder.down_blocks.3.resnets.0.norm1.bias",
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "vae_encoder.down_blocks.3.resnets.0.norm1.weight",
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "vae_encoder.down_blocks.3.resnets.0.norm2.bias",
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "vae_encoder.down_blocks.3.resnets.0.norm2.weight",
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "vae_encoder.down_blocks.3.resnets.1.conv1.bias",
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "vae_encoder.down_blocks.3.resnets.1.conv1.weight",
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "vae_encoder.down_blocks.3.resnets.1.conv2.bias",
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "vae_encoder.down_blocks.3.resnets.1.conv2.weight",
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "vae_encoder.down_blocks.3.resnets.1.norm1.bias",
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "vae_encoder.down_blocks.3.resnets.1.norm1.weight",
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "vae_encoder.down_blocks.3.resnets.1.norm2.bias",
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "vae_encoder.down_blocks.3.resnets.1.norm2.weight",
+ "first_stage_model.encoder.mid.attn_1.k.bias": "vae_encoder.mid_block.attentions.0.to_k.bias",
+ "first_stage_model.encoder.mid.attn_1.k.weight": "vae_encoder.mid_block.attentions.0.to_k.weight",
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "vae_encoder.mid_block.attentions.0.group_norm.bias",
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "vae_encoder.mid_block.attentions.0.group_norm.weight",
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "vae_encoder.mid_block.attentions.0.to_out.0.bias",
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "vae_encoder.mid_block.attentions.0.to_out.0.weight",
+ "first_stage_model.encoder.mid.attn_1.q.bias": "vae_encoder.mid_block.attentions.0.to_q.bias",
+ "first_stage_model.encoder.mid.attn_1.q.weight": "vae_encoder.mid_block.attentions.0.to_q.weight",
+ "first_stage_model.encoder.mid.attn_1.v.bias": "vae_encoder.mid_block.attentions.0.to_v.bias",
+ "first_stage_model.encoder.mid.attn_1.v.weight": "vae_encoder.mid_block.attentions.0.to_v.weight",
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "vae_encoder.mid_block.resnets.0.conv1.bias",
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "vae_encoder.mid_block.resnets.0.conv1.weight",
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "vae_encoder.mid_block.resnets.0.conv2.bias",
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "vae_encoder.mid_block.resnets.0.conv2.weight",
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "vae_encoder.mid_block.resnets.0.norm1.bias",
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "vae_encoder.mid_block.resnets.0.norm1.weight",
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "vae_encoder.mid_block.resnets.0.norm2.bias",
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "vae_encoder.mid_block.resnets.0.norm2.weight",
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "vae_encoder.mid_block.resnets.1.conv1.bias",
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "vae_encoder.mid_block.resnets.1.conv1.weight",
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "vae_encoder.mid_block.resnets.1.conv2.bias",
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "vae_encoder.mid_block.resnets.1.conv2.weight",
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "vae_encoder.mid_block.resnets.1.norm1.bias",
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "vae_encoder.mid_block.resnets.1.norm1.weight",
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "vae_encoder.mid_block.resnets.1.norm2.bias",
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "vae_encoder.mid_block.resnets.1.norm2.weight",
+ "first_stage_model.encoder.norm_out.bias": "vae_encoder.conv_norm_out.bias",
+ "first_stage_model.encoder.norm_out.weight": "vae_encoder.conv_norm_out.weight",
+ "first_stage_model.post_quant_conv.bias": "vae_post_quant_conv.bias",
+ "first_stage_model.post_quant_conv.weight": "vae_post_quant_conv.weight",
+ "first_stage_model.quant_conv.bias": "vae_quant_conv.bias",
+ "first_stage_model.quant_conv.weight": "vae_quant_conv.weight",
+ "model.diffusion_model.input_blocks.0.0.bias": "unet_conv_in.bias",
+ "model.diffusion_model.input_blocks.0.0.weight": "unet_conv_in.weight",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.1.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.bias": "unet_down_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.0.weight": "unet_down_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.bias": "unet_down_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.1.0.in_layers.2.weight": "unet_down_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.bias": "unet_down_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.0.weight": "unet_down_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.bias": "unet_down_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.1.0.out_layers.3.weight": "unet_down_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.bias": "unet_down_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.2.0.emb_layers.1.weight": "unet_down_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.bias": "unet_down_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.0.weight": "unet_down_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.bias": "unet_down_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.2.0.in_layers.2.weight": "unet_down_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.bias": "unet_down_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.0.weight": "unet_down_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.bias": "unet_down_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.2.0.out_layers.3.weight": "unet_down_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.3.0.op.bias": "unet_down_blocks.0.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.3.0.op.weight": "unet_down_blocks.0.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.4.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.bias": "unet_down_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.0.weight": "unet_down_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.bias": "unet_down_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.4.0.in_layers.2.weight": "unet_down_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.bias": "unet_down_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.0.weight": "unet_down_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.bias": "unet_down_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.4.0.out_layers.3.weight": "unet_down_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.bias": "unet_down_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.4.0.skip_connection.weight": "unet_down_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.4.1.norm.bias": "unet_down_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.4.1.norm.weight": "unet_down_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_in.bias": "unet_down_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_in.weight": "unet_down_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.4.1.proj_out.bias": "unet_down_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.4.1.proj_out.weight": "unet_down_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.bias": "unet_down_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.5.0.emb_layers.1.weight": "unet_down_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.bias": "unet_down_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.0.weight": "unet_down_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.bias": "unet_down_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.5.0.in_layers.2.weight": "unet_down_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.bias": "unet_down_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.0.weight": "unet_down_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.bias": "unet_down_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.5.0.out_layers.3.weight": "unet_down_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.5.1.norm.bias": "unet_down_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.5.1.norm.weight": "unet_down_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_in.bias": "unet_down_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_in.weight": "unet_down_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.5.1.proj_out.bias": "unet_down_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.5.1.proj_out.weight": "unet_down_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.6.0.op.bias": "unet_down_blocks.1.downsamplers.0.conv.bias",
+ "model.diffusion_model.input_blocks.6.0.op.weight": "unet_down_blocks.1.downsamplers.0.conv.weight",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.7.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.bias": "unet_down_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.0.weight": "unet_down_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.bias": "unet_down_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.input_blocks.7.0.in_layers.2.weight": "unet_down_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.bias": "unet_down_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.0.weight": "unet_down_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.bias": "unet_down_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.input_blocks.7.0.out_layers.3.weight": "unet_down_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.bias": "unet_down_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.input_blocks.7.0.skip_connection.weight": "unet_down_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.input_blocks.7.1.norm.bias": "unet_down_blocks.2.attentions.0.norm.bias",
+ "model.diffusion_model.input_blocks.7.1.norm.weight": "unet_down_blocks.2.attentions.0.norm.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_in.bias": "unet_down_blocks.2.attentions.0.proj_in.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_in.weight": "unet_down_blocks.2.attentions.0.proj_in.weight",
+ "model.diffusion_model.input_blocks.7.1.proj_out.bias": "unet_down_blocks.2.attentions.0.proj_out.bias",
+ "model.diffusion_model.input_blocks.7.1.proj_out.weight": "unet_down_blocks.2.attentions.0.proj_out.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.7.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.bias": "unet_down_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.input_blocks.8.0.emb_layers.1.weight": "unet_down_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.bias": "unet_down_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.0.weight": "unet_down_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.bias": "unet_down_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.input_blocks.8.0.in_layers.2.weight": "unet_down_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.bias": "unet_down_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.0.weight": "unet_down_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.bias": "unet_down_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.input_blocks.8.0.out_layers.3.weight": "unet_down_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.input_blocks.8.1.norm.bias": "unet_down_blocks.2.attentions.1.norm.bias",
+ "model.diffusion_model.input_blocks.8.1.norm.weight": "unet_down_blocks.2.attentions.1.norm.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_in.bias": "unet_down_blocks.2.attentions.1.proj_in.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_in.weight": "unet_down_blocks.2.attentions.1.proj_in.weight",
+ "model.diffusion_model.input_blocks.8.1.proj_out.bias": "unet_down_blocks.2.attentions.1.proj_out.bias",
+ "model.diffusion_model.input_blocks.8.1.proj_out.weight": "unet_down_blocks.2.attentions.1.proj_out.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn1.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_k.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_q.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.attn2.to_v.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.ff.net.2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm1.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm2.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.bias": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.input_blocks.8.1.transformer_blocks.1.norm3.weight": "unet_down_blocks.2.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.label_emb.0.0.bias": "unet_add_embedding.linear_1.bias",
+ "model.diffusion_model.label_emb.0.0.weight": "unet_add_embedding.linear_1.weight",
+ "model.diffusion_model.label_emb.0.2.bias": "unet_add_embedding.linear_2.bias",
+ "model.diffusion_model.label_emb.0.2.weight": "unet_add_embedding.linear_2.weight",
+ "model.diffusion_model.middle_block.0.emb_layers.1.bias": "unet_mid_block.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.middle_block.0.emb_layers.1.weight": "unet_mid_block.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.middle_block.0.in_layers.0.bias": "unet_mid_block.resnets.0.norm1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.0.weight": "unet_mid_block.resnets.0.norm1.weight",
+ "model.diffusion_model.middle_block.0.in_layers.2.bias": "unet_mid_block.resnets.0.conv1.bias",
+ "model.diffusion_model.middle_block.0.in_layers.2.weight": "unet_mid_block.resnets.0.conv1.weight",
+ "model.diffusion_model.middle_block.0.out_layers.0.bias": "unet_mid_block.resnets.0.norm2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.0.weight": "unet_mid_block.resnets.0.norm2.weight",
+ "model.diffusion_model.middle_block.0.out_layers.3.bias": "unet_mid_block.resnets.0.conv2.bias",
+ "model.diffusion_model.middle_block.0.out_layers.3.weight": "unet_mid_block.resnets.0.conv2.weight",
+ "model.diffusion_model.out.0.bias": "unet_conv_norm_out.bias",
+ "model.diffusion_model.out.0.weight": "unet_conv_norm_out.weight",
+ "model.diffusion_model.out.2.bias": "unet_conv_out.bias",
+ "model.diffusion_model.out.2.weight": "unet_conv_out.weight",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.0.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.bias": "unet_up_blocks.0.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.0.weight": "unet_up_blocks.0.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.bias": "unet_up_blocks.0.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.0.0.in_layers.2.weight": "unet_up_blocks.0.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.bias": "unet_up_blocks.0.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.0.weight": "unet_up_blocks.0.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.bias": "unet_up_blocks.0.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.0.0.out_layers.3.weight": "unet_up_blocks.0.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.bias": "unet_up_blocks.0.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.0.0.skip_connection.weight": "unet_up_blocks.0.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.0.1.norm.bias": "unet_up_blocks.0.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.0.1.norm.weight": "unet_up_blocks.0.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_in.bias": "unet_up_blocks.0.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_in.weight": "unet_up_blocks.0.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.0.1.proj_out.bias": "unet_up_blocks.0.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.0.1.proj_out.weight": "unet_up_blocks.0.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.0.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.0.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.1.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.bias": "unet_up_blocks.0.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.0.weight": "unet_up_blocks.0.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.bias": "unet_up_blocks.0.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.1.0.in_layers.2.weight": "unet_up_blocks.0.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.bias": "unet_up_blocks.0.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.0.weight": "unet_up_blocks.0.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.bias": "unet_up_blocks.0.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.1.0.out_layers.3.weight": "unet_up_blocks.0.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.bias": "unet_up_blocks.0.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.1.0.skip_connection.weight": "unet_up_blocks.0.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.1.1.norm.bias": "unet_up_blocks.0.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.1.1.norm.weight": "unet_up_blocks.0.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_in.bias": "unet_up_blocks.0.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_in.weight": "unet_up_blocks.0.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.1.1.proj_out.bias": "unet_up_blocks.0.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.1.1.proj_out.weight": "unet_up_blocks.0.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.1.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.1.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.bias": "unet_up_blocks.0.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.2.0.emb_layers.1.weight": "unet_up_blocks.0.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.bias": "unet_up_blocks.0.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.0.weight": "unet_up_blocks.0.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.bias": "unet_up_blocks.0.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.2.0.in_layers.2.weight": "unet_up_blocks.0.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.bias": "unet_up_blocks.0.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.0.weight": "unet_up_blocks.0.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.bias": "unet_up_blocks.0.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.2.0.out_layers.3.weight": "unet_up_blocks.0.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.bias": "unet_up_blocks.0.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.2.0.skip_connection.weight": "unet_up_blocks.0.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.2.1.norm.bias": "unet_up_blocks.0.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.2.1.norm.weight": "unet_up_blocks.0.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_in.bias": "unet_up_blocks.0.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_in.weight": "unet_up_blocks.0.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.2.1.proj_out.bias": "unet_up_blocks.0.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.2.1.proj_out.weight": "unet_up_blocks.0.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn1.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_k.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_out.0.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_q.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.attn2.to_v.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.0.proj.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.ff.net.2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm1.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm1.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm2.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm2.weight",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.bias": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.bias",
+ "model.diffusion_model.output_blocks.2.1.transformer_blocks.1.norm3.weight": "unet_up_blocks.0.attentions.2.transformer_blocks.1.norm3.weight",
+ "model.diffusion_model.output_blocks.2.2.conv.bias": "unet_up_blocks.0.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.2.2.conv.weight": "unet_up_blocks.0.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.3.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.bias": "unet_up_blocks.1.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.0.weight": "unet_up_blocks.1.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.bias": "unet_up_blocks.1.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.3.0.in_layers.2.weight": "unet_up_blocks.1.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.bias": "unet_up_blocks.1.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.0.weight": "unet_up_blocks.1.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.bias": "unet_up_blocks.1.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.3.0.out_layers.3.weight": "unet_up_blocks.1.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.bias": "unet_up_blocks.1.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.3.0.skip_connection.weight": "unet_up_blocks.1.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.3.1.norm.bias": "unet_up_blocks.1.attentions.0.norm.bias",
+ "model.diffusion_model.output_blocks.3.1.norm.weight": "unet_up_blocks.1.attentions.0.norm.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_in.bias": "unet_up_blocks.1.attentions.0.proj_in.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_in.weight": "unet_up_blocks.1.attentions.0.proj_in.weight",
+ "model.diffusion_model.output_blocks.3.1.proj_out.bias": "unet_up_blocks.1.attentions.0.proj_out.bias",
+ "model.diffusion_model.output_blocks.3.1.proj_out.weight": "unet_up_blocks.1.attentions.0.proj_out.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.0.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.4.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.bias": "unet_up_blocks.1.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.0.weight": "unet_up_blocks.1.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.bias": "unet_up_blocks.1.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.4.0.in_layers.2.weight": "unet_up_blocks.1.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.bias": "unet_up_blocks.1.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.0.weight": "unet_up_blocks.1.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.bias": "unet_up_blocks.1.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.4.0.out_layers.3.weight": "unet_up_blocks.1.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.bias": "unet_up_blocks.1.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.4.0.skip_connection.weight": "unet_up_blocks.1.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.4.1.norm.bias": "unet_up_blocks.1.attentions.1.norm.bias",
+ "model.diffusion_model.output_blocks.4.1.norm.weight": "unet_up_blocks.1.attentions.1.norm.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_in.bias": "unet_up_blocks.1.attentions.1.proj_in.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_in.weight": "unet_up_blocks.1.attentions.1.proj_in.weight",
+ "model.diffusion_model.output_blocks.4.1.proj_out.bias": "unet_up_blocks.1.attentions.1.proj_out.bias",
+ "model.diffusion_model.output_blocks.4.1.proj_out.weight": "unet_up_blocks.1.attentions.1.proj_out.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.1.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.bias": "unet_up_blocks.1.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.5.0.emb_layers.1.weight": "unet_up_blocks.1.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.bias": "unet_up_blocks.1.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.0.weight": "unet_up_blocks.1.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.bias": "unet_up_blocks.1.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.5.0.in_layers.2.weight": "unet_up_blocks.1.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.bias": "unet_up_blocks.1.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.0.weight": "unet_up_blocks.1.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.bias": "unet_up_blocks.1.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.5.0.out_layers.3.weight": "unet_up_blocks.1.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.bias": "unet_up_blocks.1.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.5.0.skip_connection.weight": "unet_up_blocks.1.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.5.1.norm.bias": "unet_up_blocks.1.attentions.2.norm.bias",
+ "model.diffusion_model.output_blocks.5.1.norm.weight": "unet_up_blocks.1.attentions.2.norm.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_in.bias": "unet_up_blocks.1.attentions.2.proj_in.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_in.weight": "unet_up_blocks.1.attentions.2.proj_in.weight",
+ "model.diffusion_model.output_blocks.5.1.proj_out.bias": "unet_up_blocks.1.attentions.2.proj_out.bias",
+ "model.diffusion_model.output_blocks.5.1.proj_out.weight": "unet_up_blocks.1.attentions.2.proj_out.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm1.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm2.weight",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.bias": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.bias",
+ "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3.weight": "unet_up_blocks.1.attentions.2.transformer_blocks.0.norm3.weight",
+ "model.diffusion_model.output_blocks.5.2.conv.bias": "unet_up_blocks.1.upsamplers.0.conv.bias",
+ "model.diffusion_model.output_blocks.5.2.conv.weight": "unet_up_blocks.1.upsamplers.0.conv.weight",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.0.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.6.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.0.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.bias": "unet_up_blocks.2.resnets.0.norm1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.0.weight": "unet_up_blocks.2.resnets.0.norm1.weight",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.bias": "unet_up_blocks.2.resnets.0.conv1.bias",
+ "model.diffusion_model.output_blocks.6.0.in_layers.2.weight": "unet_up_blocks.2.resnets.0.conv1.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.bias": "unet_up_blocks.2.resnets.0.norm2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.0.weight": "unet_up_blocks.2.resnets.0.norm2.weight",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.bias": "unet_up_blocks.2.resnets.0.conv2.bias",
+ "model.diffusion_model.output_blocks.6.0.out_layers.3.weight": "unet_up_blocks.2.resnets.0.conv2.weight",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.bias": "unet_up_blocks.2.resnets.0.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.6.0.skip_connection.weight": "unet_up_blocks.2.resnets.0.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.1.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.7.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.1.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.bias": "unet_up_blocks.2.resnets.1.norm1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.0.weight": "unet_up_blocks.2.resnets.1.norm1.weight",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.bias": "unet_up_blocks.2.resnets.1.conv1.bias",
+ "model.diffusion_model.output_blocks.7.0.in_layers.2.weight": "unet_up_blocks.2.resnets.1.conv1.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.bias": "unet_up_blocks.2.resnets.1.norm2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.0.weight": "unet_up_blocks.2.resnets.1.norm2.weight",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.bias": "unet_up_blocks.2.resnets.1.conv2.bias",
+ "model.diffusion_model.output_blocks.7.0.out_layers.3.weight": "unet_up_blocks.2.resnets.1.conv2.weight",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.bias": "unet_up_blocks.2.resnets.1.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.7.0.skip_connection.weight": "unet_up_blocks.2.resnets.1.conv_shortcut.weight",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.bias": "unet_up_blocks.2.resnets.2.time_emb_proj.bias",
+ "model.diffusion_model.output_blocks.8.0.emb_layers.1.weight": "unet_up_blocks.2.resnets.2.time_emb_proj.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.bias": "unet_up_blocks.2.resnets.2.norm1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.0.weight": "unet_up_blocks.2.resnets.2.norm1.weight",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.bias": "unet_up_blocks.2.resnets.2.conv1.bias",
+ "model.diffusion_model.output_blocks.8.0.in_layers.2.weight": "unet_up_blocks.2.resnets.2.conv1.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.bias": "unet_up_blocks.2.resnets.2.norm2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.0.weight": "unet_up_blocks.2.resnets.2.norm2.weight",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.bias": "unet_up_blocks.2.resnets.2.conv2.bias",
+ "model.diffusion_model.output_blocks.8.0.out_layers.3.weight": "unet_up_blocks.2.resnets.2.conv2.weight",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.bias": "unet_up_blocks.2.resnets.2.conv_shortcut.bias",
+ "model.diffusion_model.output_blocks.8.0.skip_connection.weight": "unet_up_blocks.2.resnets.2.conv_shortcut.weight",
+ "model.diffusion_model.time_embed.0.bias": "unet_time_embedding.linear_1.bias",
+ "model.diffusion_model.time_embed.0.weight": "unet_time_embedding.linear_1.weight",
+ "model.diffusion_model.time_embed.2.bias": "unet_time_embedding.linear_2.bias",
+ "model.diffusion_model.time_embed.2.weight": "unet_time_embedding.linear_2.weight"
+ },
+ "ldm_diffusers_shape_map": {
+ "first_stage_model.decoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.decoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.k.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.q.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ],
+ "first_stage_model.encoder.mid.attn_1.v.weight": [
+ [
+ 512,
+ 512,
+ 1,
+ 1
+ ],
+ [
+ 512,
+ 512
+ ]
+ ]
+ },
+ "ldm_diffusers_operator_map": {
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias"
+ ]
+ },
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": {
+ "cat": [
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight",
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight"
+ ]
+ }
+ },
+ "diffusers_ldm_operator_map": {
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.0.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.1.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.10.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.11.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.12.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.13.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.14.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.15.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.16.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.17.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.18.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.19.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.2.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.20.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.21.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.22.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.23.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.24.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.25.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.26.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.27.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.28.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.29.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.3.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.30.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.31.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.4.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.5.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.6.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.7.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.8.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.bias": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias",
+ "2560:, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.q_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "0:1280, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.k_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "1280:2560, :"
+ ]
+ },
+ "te1_text_model.encoder.layers.9.self_attn.v_proj.weight": {
+ "slice": [
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight",
+ "2560:, :"
+ ]
+ }
+ }
+}
\ No newline at end of file
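The operator maps in this keymap file are essentially move instructions between the diffusers layout (separate q_proj/k_proj/v_proj tensors per attention layer) and the LDM layout (one fused in_proj tensor per resblock): a "cat" entry concatenates the listed diffusers tensors into the fused LDM key, while a "slice" entry cuts the corresponding range back out of it. As a rough illustration only (the helpers below, parse_slice and apply_operator_map, are hypothetical and not part of this toolkit), such a map could be applied to a state dict roughly like this:

import torch

def parse_slice(spec: str):
    # "1280:2560, :" -> (slice(1280, 2560), slice(None, None))
    dims = []
    for part in spec.split(","):
        start, _, stop = part.strip().partition(":")
        dims.append(slice(int(start) if start else None, int(stop) if stop else None))
    return tuple(dims)

def apply_operator_map(op_map: dict, source_sd: dict) -> dict:
    converted = {}
    for target_key, op in op_map.items():
        if "cat" in op:
            # fuse q/k/v weights or biases along dim 0 into the combined in_proj tensor
            converted[target_key] = torch.cat([source_sd[k] for k in op["cat"]], dim=0)
        elif "slice" in op:
            source_key, spec = op["slice"]
            tensor = source_sd[source_key]
            # biases are 1-D, so only keep as many slice dims as the tensor has
            converted[target_key] = tensor[parse_slice(spec)[: tensor.dim()]]
    return converted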
diff --git a/toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors b/toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..8e2c4cb90b8d10d6c9a844a3b73ef3e07541f130
--- /dev/null
+++ b/toolkit/keymaps/stable_diffusion_vega_ldm_base.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9bbcbf73561f6bc5d0a17ea6a2081feed2d1304e87602d8c502d9a5c4bd85576
+size 16
diff --git a/toolkit/kohya_model_util.py b/toolkit/kohya_model_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..798fc2dccb5787973cdf2bbab769459f22b3a805
--- /dev/null
+++ b/toolkit/kohya_model_util.py
@@ -0,0 +1,1533 @@
+# mostly from https://github.com/kohya-ss/sd-scripts/blob/main/library/model_util.py
+# I am infinitely grateful to @kohya-ss for their amazing work in this field.
+# This version is updated to handle the latest version of the diffusers library.
+import json
+# v1: split from train_db_fixed.py.
+# v2: support safetensors
+
+import math
+import os
+import re
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+# Model parameters for the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+# UNET_PARAMS_USE_LINEAR_PROJECTION = False
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
+
+# Reference models used to load the Diffusers configuration
+DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
+DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
+
+
+# region StableDiffusion -> Diffusers conversion code
+# copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # updated for latest diffusers
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "to_q.weight")
+ new_item = new_item.replace("q.bias", "to_q.bias")
+
+ new_item = new_item.replace("k.weight", "to_k.weight")
+ new_item = new_item.replace("k.bias", "to_k.bias")
+
+ new_item = new_item.replace("v.weight", "to_v.weight")
+ new_item = new_item.replace("v.bias", "to_v.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming
+ to them. It splits attention layers, and takes into account additional replacements
+ that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
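+            # the fused in_proj tensor stacks q, k and v along dim 0, so each projection owns one third of it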
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+ shape = old_checkpoint[path["old"]].shape
+ if is_attn_weight and len(shape) == 3:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ elif is_attn_weight and len(shape) == 4:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def linear_transformer_to_conv(checkpoint):
+ keys = list(checkpoint.keys())
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in tf_keys:
+ if checkpoint[key].ndim == 2:
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
+
+
+def convert_ldm_unet_checkpoint(v2, checkpoint, config):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
+    """
+    mapping = {}
+
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ unet_key = "model.diffusion_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in
+ range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in
+ range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in
+ range(num_output_blocks)
+ }
+
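+ # Down path: LDM input_blocks.{i} (i >= 1) maps to Diffusers down_blocks.{block_id};
+ # sub-layer 0 is a resnet (or a downsampler "op"), sub-layer 1 is an attention block.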
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [key for key in input_blocks[i] if
+ f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ mapping[f'input_blocks.{i}.0.op.weight'] = f"down_blocks.{block_id}.downsamplers.0.conv.weight"
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias")
+ mapping[f'input_blocks.{i}.0.op.bias'] = f"down_blocks.{block_id}.downsamplers.0.conv.bias"
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+ config=config)
+
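+ # Middle block: middle_block.0/1/2 map to mid_block.resnets.0, mid_block.attentions.0 and mid_block.resnets.1.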
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+ config=config)
+
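+ # Up path: output_blocks.{i} maps to up_blocks.{block_id}; blocks with more than one sub-layer
+ # (resnet plus attention and/or upsampler) take the first branch, resnet-only blocks fall through
+ # to the else branch below.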
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+ config=config)
+
+ # Original:
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+ # Avoid depending on the order of bias/weight; there is probably a nicer way to do this
+ for l in output_block_list.values():
+ l.sort()
+
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path],
+ config=config)
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ # In SD v2 the 1x1 conv2d layers were replaced with linear layers.
+ # The Diffusers side was mistakenly left as conv2d, so a conversion is needed here.
+ if v2 and not config.get('use_linear_projection', False):
+ linear_transformer_to_conv(new_checkpoint)
+
+ # print("mapping: ", json.dumps(mapping, indent=4))
+ return new_checkpoint
+
+
+# VAE key mapping: LDM key -> Diffusers key
+vae_ldm_to_diffusers_dict = {
+ "decoder.conv_in.bias": "decoder.conv_in.bias",
+ "decoder.conv_in.weight": "decoder.conv_in.weight",
+ "decoder.conv_out.bias": "decoder.conv_out.bias",
+ "decoder.conv_out.weight": "decoder.conv_out.weight",
+ "decoder.mid.attn_1.k.bias": "decoder.mid_block.attentions.0.to_k.bias",
+ "decoder.mid.attn_1.k.weight": "decoder.mid_block.attentions.0.to_k.weight",
+ "decoder.mid.attn_1.norm.bias": "decoder.mid_block.attentions.0.group_norm.bias",
+ "decoder.mid.attn_1.norm.weight": "decoder.mid_block.attentions.0.group_norm.weight",
+ "decoder.mid.attn_1.proj_out.bias": "decoder.mid_block.attentions.0.to_out.0.bias",
+ "decoder.mid.attn_1.proj_out.weight": "decoder.mid_block.attentions.0.to_out.0.weight",
+ "decoder.mid.attn_1.q.bias": "decoder.mid_block.attentions.0.to_q.bias",
+ "decoder.mid.attn_1.q.weight": "decoder.mid_block.attentions.0.to_q.weight",
+ "decoder.mid.attn_1.v.bias": "decoder.mid_block.attentions.0.to_v.bias",
+ "decoder.mid.attn_1.v.weight": "decoder.mid_block.attentions.0.to_v.weight",
+ "decoder.mid.block_1.conv1.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.mid.block_1.conv1.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.mid.block_1.conv2.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.mid.block_1.conv2.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.mid.block_1.norm1.bias": "decoder.mid_block.resnets.0.norm1.bias",
+ "decoder.mid.block_1.norm1.weight": "decoder.mid_block.resnets.0.norm1.weight",
+ "decoder.mid.block_1.norm2.bias": "decoder.mid_block.resnets.0.norm2.bias",
+ "decoder.mid.block_1.norm2.weight": "decoder.mid_block.resnets.0.norm2.weight",
+ "decoder.mid.block_2.conv1.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.mid.block_2.conv1.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.mid.block_2.conv2.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.mid.block_2.conv2.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ "decoder.mid.block_2.norm1.bias": "decoder.mid_block.resnets.1.norm1.bias",
+ "decoder.mid.block_2.norm1.weight": "decoder.mid_block.resnets.1.norm1.weight",
+ "decoder.mid.block_2.norm2.bias": "decoder.mid_block.resnets.1.norm2.bias",
+ "decoder.mid.block_2.norm2.weight": "decoder.mid_block.resnets.1.norm2.weight",
+ "decoder.norm_out.bias": "decoder.conv_norm_out.bias",
+ "decoder.norm_out.weight": "decoder.conv_norm_out.weight",
+ "decoder.up.0.block.0.conv1.bias": "decoder.up_blocks.3.resnets.0.conv1.bias",
+ "decoder.up.0.block.0.conv1.weight": "decoder.up_blocks.3.resnets.0.conv1.weight",
+ "decoder.up.0.block.0.conv2.bias": "decoder.up_blocks.3.resnets.0.conv2.bias",
+ "decoder.up.0.block.0.conv2.weight": "decoder.up_blocks.3.resnets.0.conv2.weight",
+ "decoder.up.0.block.0.nin_shortcut.bias": "decoder.up_blocks.3.resnets.0.conv_shortcut.bias",
+ "decoder.up.0.block.0.nin_shortcut.weight": "decoder.up_blocks.3.resnets.0.conv_shortcut.weight",
+ "decoder.up.0.block.0.norm1.bias": "decoder.up_blocks.3.resnets.0.norm1.bias",
+ "decoder.up.0.block.0.norm1.weight": "decoder.up_blocks.3.resnets.0.norm1.weight",
+ "decoder.up.0.block.0.norm2.bias": "decoder.up_blocks.3.resnets.0.norm2.bias",
+ "decoder.up.0.block.0.norm2.weight": "decoder.up_blocks.3.resnets.0.norm2.weight",
+ "decoder.up.0.block.1.conv1.bias": "decoder.up_blocks.3.resnets.1.conv1.bias",
+ "decoder.up.0.block.1.conv1.weight": "decoder.up_blocks.3.resnets.1.conv1.weight",
+ "decoder.up.0.block.1.conv2.bias": "decoder.up_blocks.3.resnets.1.conv2.bias",
+ "decoder.up.0.block.1.conv2.weight": "decoder.up_blocks.3.resnets.1.conv2.weight",
+ "decoder.up.0.block.1.norm1.bias": "decoder.up_blocks.3.resnets.1.norm1.bias",
+ "decoder.up.0.block.1.norm1.weight": "decoder.up_blocks.3.resnets.1.norm1.weight",
+ "decoder.up.0.block.1.norm2.bias": "decoder.up_blocks.3.resnets.1.norm2.bias",
+ "decoder.up.0.block.1.norm2.weight": "decoder.up_blocks.3.resnets.1.norm2.weight",
+ "decoder.up.0.block.2.conv1.bias": "decoder.up_blocks.3.resnets.2.conv1.bias",
+ "decoder.up.0.block.2.conv1.weight": "decoder.up_blocks.3.resnets.2.conv1.weight",
+ "decoder.up.0.block.2.conv2.bias": "decoder.up_blocks.3.resnets.2.conv2.bias",
+ "decoder.up.0.block.2.conv2.weight": "decoder.up_blocks.3.resnets.2.conv2.weight",
+ "decoder.up.0.block.2.norm1.bias": "decoder.up_blocks.3.resnets.2.norm1.bias",
+ "decoder.up.0.block.2.norm1.weight": "decoder.up_blocks.3.resnets.2.norm1.weight",
+ "decoder.up.0.block.2.norm2.bias": "decoder.up_blocks.3.resnets.2.norm2.bias",
+ "decoder.up.0.block.2.norm2.weight": "decoder.up_blocks.3.resnets.2.norm2.weight",
+ "decoder.up.1.block.0.conv1.bias": "decoder.up_blocks.2.resnets.0.conv1.bias",
+ "decoder.up.1.block.0.conv1.weight": "decoder.up_blocks.2.resnets.0.conv1.weight",
+ "decoder.up.1.block.0.conv2.bias": "decoder.up_blocks.2.resnets.0.conv2.bias",
+ "decoder.up.1.block.0.conv2.weight": "decoder.up_blocks.2.resnets.0.conv2.weight",
+ "decoder.up.1.block.0.nin_shortcut.bias": "decoder.up_blocks.2.resnets.0.conv_shortcut.bias",
+ "decoder.up.1.block.0.nin_shortcut.weight": "decoder.up_blocks.2.resnets.0.conv_shortcut.weight",
+ "decoder.up.1.block.0.norm1.bias": "decoder.up_blocks.2.resnets.0.norm1.bias",
+ "decoder.up.1.block.0.norm1.weight": "decoder.up_blocks.2.resnets.0.norm1.weight",
+ "decoder.up.1.block.0.norm2.bias": "decoder.up_blocks.2.resnets.0.norm2.bias",
+ "decoder.up.1.block.0.norm2.weight": "decoder.up_blocks.2.resnets.0.norm2.weight",
+ "decoder.up.1.block.1.conv1.bias": "decoder.up_blocks.2.resnets.1.conv1.bias",
+ "decoder.up.1.block.1.conv1.weight": "decoder.up_blocks.2.resnets.1.conv1.weight",
+ "decoder.up.1.block.1.conv2.bias": "decoder.up_blocks.2.resnets.1.conv2.bias",
+ "decoder.up.1.block.1.conv2.weight": "decoder.up_blocks.2.resnets.1.conv2.weight",
+ "decoder.up.1.block.1.norm1.bias": "decoder.up_blocks.2.resnets.1.norm1.bias",
+ "decoder.up.1.block.1.norm1.weight": "decoder.up_blocks.2.resnets.1.norm1.weight",
+ "decoder.up.1.block.1.norm2.bias": "decoder.up_blocks.2.resnets.1.norm2.bias",
+ "decoder.up.1.block.1.norm2.weight": "decoder.up_blocks.2.resnets.1.norm2.weight",
+ "decoder.up.1.block.2.conv1.bias": "decoder.up_blocks.2.resnets.2.conv1.bias",
+ "decoder.up.1.block.2.conv1.weight": "decoder.up_blocks.2.resnets.2.conv1.weight",
+ "decoder.up.1.block.2.conv2.bias": "decoder.up_blocks.2.resnets.2.conv2.bias",
+ "decoder.up.1.block.2.conv2.weight": "decoder.up_blocks.2.resnets.2.conv2.weight",
+ "decoder.up.1.block.2.norm1.bias": "decoder.up_blocks.2.resnets.2.norm1.bias",
+ "decoder.up.1.block.2.norm1.weight": "decoder.up_blocks.2.resnets.2.norm1.weight",
+ "decoder.up.1.block.2.norm2.bias": "decoder.up_blocks.2.resnets.2.norm2.bias",
+ "decoder.up.1.block.2.norm2.weight": "decoder.up_blocks.2.resnets.2.norm2.weight",
+ "decoder.up.1.upsample.conv.bias": "decoder.up_blocks.2.upsamplers.0.conv.bias",
+ "decoder.up.1.upsample.conv.weight": "decoder.up_blocks.2.upsamplers.0.conv.weight",
+ "decoder.up.2.block.0.conv1.bias": "decoder.up_blocks.1.resnets.0.conv1.bias",
+ "decoder.up.2.block.0.conv1.weight": "decoder.up_blocks.1.resnets.0.conv1.weight",
+ "decoder.up.2.block.0.conv2.bias": "decoder.up_blocks.1.resnets.0.conv2.bias",
+ "decoder.up.2.block.0.conv2.weight": "decoder.up_blocks.1.resnets.0.conv2.weight",
+ "decoder.up.2.block.0.norm1.bias": "decoder.up_blocks.1.resnets.0.norm1.bias",
+ "decoder.up.2.block.0.norm1.weight": "decoder.up_blocks.1.resnets.0.norm1.weight",
+ "decoder.up.2.block.0.norm2.bias": "decoder.up_blocks.1.resnets.0.norm2.bias",
+ "decoder.up.2.block.0.norm2.weight": "decoder.up_blocks.1.resnets.0.norm2.weight",
+ "decoder.up.2.block.1.conv1.bias": "decoder.up_blocks.1.resnets.1.conv1.bias",
+ "decoder.up.2.block.1.conv1.weight": "decoder.up_blocks.1.resnets.1.conv1.weight",
+ "decoder.up.2.block.1.conv2.bias": "decoder.up_blocks.1.resnets.1.conv2.bias",
+ "decoder.up.2.block.1.conv2.weight": "decoder.up_blocks.1.resnets.1.conv2.weight",
+ "decoder.up.2.block.1.norm1.bias": "decoder.up_blocks.1.resnets.1.norm1.bias",
+ "decoder.up.2.block.1.norm1.weight": "decoder.up_blocks.1.resnets.1.norm1.weight",
+ "decoder.up.2.block.1.norm2.bias": "decoder.up_blocks.1.resnets.1.norm2.bias",
+ "decoder.up.2.block.1.norm2.weight": "decoder.up_blocks.1.resnets.1.norm2.weight",
+ "decoder.up.2.block.2.conv1.bias": "decoder.up_blocks.1.resnets.2.conv1.bias",
+ "decoder.up.2.block.2.conv1.weight": "decoder.up_blocks.1.resnets.2.conv1.weight",
+ "decoder.up.2.block.2.conv2.bias": "decoder.up_blocks.1.resnets.2.conv2.bias",
+ "decoder.up.2.block.2.conv2.weight": "decoder.up_blocks.1.resnets.2.conv2.weight",
+ "decoder.up.2.block.2.norm1.bias": "decoder.up_blocks.1.resnets.2.norm1.bias",
+ "decoder.up.2.block.2.norm1.weight": "decoder.up_blocks.1.resnets.2.norm1.weight",
+ "decoder.up.2.block.2.norm2.bias": "decoder.up_blocks.1.resnets.2.norm2.bias",
+ "decoder.up.2.block.2.norm2.weight": "decoder.up_blocks.1.resnets.2.norm2.weight",
+ "decoder.up.2.upsample.conv.bias": "decoder.up_blocks.1.upsamplers.0.conv.bias",
+ "decoder.up.2.upsample.conv.weight": "decoder.up_blocks.1.upsamplers.0.conv.weight",
+ "decoder.up.3.block.0.conv1.bias": "decoder.up_blocks.0.resnets.0.conv1.bias",
+ "decoder.up.3.block.0.conv1.weight": "decoder.up_blocks.0.resnets.0.conv1.weight",
+ "decoder.up.3.block.0.conv2.bias": "decoder.up_blocks.0.resnets.0.conv2.bias",
+ "decoder.up.3.block.0.conv2.weight": "decoder.up_blocks.0.resnets.0.conv2.weight",
+ "decoder.up.3.block.0.norm1.bias": "decoder.up_blocks.0.resnets.0.norm1.bias",
+ "decoder.up.3.block.0.norm1.weight": "decoder.up_blocks.0.resnets.0.norm1.weight",
+ "decoder.up.3.block.0.norm2.bias": "decoder.up_blocks.0.resnets.0.norm2.bias",
+ "decoder.up.3.block.0.norm2.weight": "decoder.up_blocks.0.resnets.0.norm2.weight",
+ "decoder.up.3.block.1.conv1.bias": "decoder.up_blocks.0.resnets.1.conv1.bias",
+ "decoder.up.3.block.1.conv1.weight": "decoder.up_blocks.0.resnets.1.conv1.weight",
+ "decoder.up.3.block.1.conv2.bias": "decoder.up_blocks.0.resnets.1.conv2.bias",
+ "decoder.up.3.block.1.conv2.weight": "decoder.up_blocks.0.resnets.1.conv2.weight",
+ "decoder.up.3.block.1.norm1.bias": "decoder.up_blocks.0.resnets.1.norm1.bias",
+ "decoder.up.3.block.1.norm1.weight": "decoder.up_blocks.0.resnets.1.norm1.weight",
+ "decoder.up.3.block.1.norm2.bias": "decoder.up_blocks.0.resnets.1.norm2.bias",
+ "decoder.up.3.block.1.norm2.weight": "decoder.up_blocks.0.resnets.1.norm2.weight",
+ "decoder.up.3.block.2.conv1.bias": "decoder.up_blocks.0.resnets.2.conv1.bias",
+ "decoder.up.3.block.2.conv1.weight": "decoder.up_blocks.0.resnets.2.conv1.weight",
+ "decoder.up.3.block.2.conv2.bias": "decoder.up_blocks.0.resnets.2.conv2.bias",
+ "decoder.up.3.block.2.conv2.weight": "decoder.up_blocks.0.resnets.2.conv2.weight",
+ "decoder.up.3.block.2.norm1.bias": "decoder.up_blocks.0.resnets.2.norm1.bias",
+ "decoder.up.3.block.2.norm1.weight": "decoder.up_blocks.0.resnets.2.norm1.weight",
+ "decoder.up.3.block.2.norm2.bias": "decoder.up_blocks.0.resnets.2.norm2.bias",
+ "decoder.up.3.block.2.norm2.weight": "decoder.up_blocks.0.resnets.2.norm2.weight",
+ "decoder.up.3.upsample.conv.bias": "decoder.up_blocks.0.upsamplers.0.conv.bias",
+ "decoder.up.3.upsample.conv.weight": "decoder.up_blocks.0.upsamplers.0.conv.weight",
+ "encoder.conv_in.bias": "encoder.conv_in.bias",
+ "encoder.conv_in.weight": "encoder.conv_in.weight",
+ "encoder.conv_out.bias": "encoder.conv_out.bias",
+ "encoder.conv_out.weight": "encoder.conv_out.weight",
+ "encoder.down.0.block.0.conv1.bias": "encoder.down_blocks.0.resnets.0.conv1.bias",
+ "encoder.down.0.block.0.conv1.weight": "encoder.down_blocks.0.resnets.0.conv1.weight",
+ "encoder.down.0.block.0.conv2.bias": "encoder.down_blocks.0.resnets.0.conv2.bias",
+ "encoder.down.0.block.0.conv2.weight": "encoder.down_blocks.0.resnets.0.conv2.weight",
+ "encoder.down.0.block.0.norm1.bias": "encoder.down_blocks.0.resnets.0.norm1.bias",
+ "encoder.down.0.block.0.norm1.weight": "encoder.down_blocks.0.resnets.0.norm1.weight",
+ "encoder.down.0.block.0.norm2.bias": "encoder.down_blocks.0.resnets.0.norm2.bias",
+ "encoder.down.0.block.0.norm2.weight": "encoder.down_blocks.0.resnets.0.norm2.weight",
+ "encoder.down.0.block.1.conv1.bias": "encoder.down_blocks.0.resnets.1.conv1.bias",
+ "encoder.down.0.block.1.conv1.weight": "encoder.down_blocks.0.resnets.1.conv1.weight",
+ "encoder.down.0.block.1.conv2.bias": "encoder.down_blocks.0.resnets.1.conv2.bias",
+ "encoder.down.0.block.1.conv2.weight": "encoder.down_blocks.0.resnets.1.conv2.weight",
+ "encoder.down.0.block.1.norm1.bias": "encoder.down_blocks.0.resnets.1.norm1.bias",
+ "encoder.down.0.block.1.norm1.weight": "encoder.down_blocks.0.resnets.1.norm1.weight",
+ "encoder.down.0.block.1.norm2.bias": "encoder.down_blocks.0.resnets.1.norm2.bias",
+ "encoder.down.0.block.1.norm2.weight": "encoder.down_blocks.0.resnets.1.norm2.weight",
+ "encoder.down.0.downsample.conv.bias": "encoder.down_blocks.0.downsamplers.0.conv.bias",
+ "encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight",
+ "encoder.down.1.block.0.conv1.bias": "encoder.down_blocks.1.resnets.0.conv1.bias",
+ "encoder.down.1.block.0.conv1.weight": "encoder.down_blocks.1.resnets.0.conv1.weight",
+ "encoder.down.1.block.0.conv2.bias": "encoder.down_blocks.1.resnets.0.conv2.bias",
+ "encoder.down.1.block.0.conv2.weight": "encoder.down_blocks.1.resnets.0.conv2.weight",
+ "encoder.down.1.block.0.nin_shortcut.bias": "encoder.down_blocks.1.resnets.0.conv_shortcut.bias",
+ "encoder.down.1.block.0.nin_shortcut.weight": "encoder.down_blocks.1.resnets.0.conv_shortcut.weight",
+ "encoder.down.1.block.0.norm1.bias": "encoder.down_blocks.1.resnets.0.norm1.bias",
+ "encoder.down.1.block.0.norm1.weight": "encoder.down_blocks.1.resnets.0.norm1.weight",
+ "encoder.down.1.block.0.norm2.bias": "encoder.down_blocks.1.resnets.0.norm2.bias",
+ "encoder.down.1.block.0.norm2.weight": "encoder.down_blocks.1.resnets.0.norm2.weight",
+ "encoder.down.1.block.1.conv1.bias": "encoder.down_blocks.1.resnets.1.conv1.bias",
+ "encoder.down.1.block.1.conv1.weight": "encoder.down_blocks.1.resnets.1.conv1.weight",
+ "encoder.down.1.block.1.conv2.bias": "encoder.down_blocks.1.resnets.1.conv2.bias",
+ "encoder.down.1.block.1.conv2.weight": "encoder.down_blocks.1.resnets.1.conv2.weight",
+ "encoder.down.1.block.1.norm1.bias": "encoder.down_blocks.1.resnets.1.norm1.bias",
+ "encoder.down.1.block.1.norm1.weight": "encoder.down_blocks.1.resnets.1.norm1.weight",
+ "encoder.down.1.block.1.norm2.bias": "encoder.down_blocks.1.resnets.1.norm2.bias",
+ "encoder.down.1.block.1.norm2.weight": "encoder.down_blocks.1.resnets.1.norm2.weight",
+ "encoder.down.1.downsample.conv.bias": "encoder.down_blocks.1.downsamplers.0.conv.bias",
+ "encoder.down.1.downsample.conv.weight": "encoder.down_blocks.1.downsamplers.0.conv.weight",
+ "encoder.down.2.block.0.conv1.bias": "encoder.down_blocks.2.resnets.0.conv1.bias",
+ "encoder.down.2.block.0.conv1.weight": "encoder.down_blocks.2.resnets.0.conv1.weight",
+ "encoder.down.2.block.0.conv2.bias": "encoder.down_blocks.2.resnets.0.conv2.bias",
+ "encoder.down.2.block.0.conv2.weight": "encoder.down_blocks.2.resnets.0.conv2.weight",
+ "encoder.down.2.block.0.nin_shortcut.bias": "encoder.down_blocks.2.resnets.0.conv_shortcut.bias",
+ "encoder.down.2.block.0.nin_shortcut.weight": "encoder.down_blocks.2.resnets.0.conv_shortcut.weight",
+ "encoder.down.2.block.0.norm1.bias": "encoder.down_blocks.2.resnets.0.norm1.bias",
+ "encoder.down.2.block.0.norm1.weight": "encoder.down_blocks.2.resnets.0.norm1.weight",
+ "encoder.down.2.block.0.norm2.bias": "encoder.down_blocks.2.resnets.0.norm2.bias",
+ "encoder.down.2.block.0.norm2.weight": "encoder.down_blocks.2.resnets.0.norm2.weight",
+ "encoder.down.2.block.1.conv1.bias": "encoder.down_blocks.2.resnets.1.conv1.bias",
+ "encoder.down.2.block.1.conv1.weight": "encoder.down_blocks.2.resnets.1.conv1.weight",
+ "encoder.down.2.block.1.conv2.bias": "encoder.down_blocks.2.resnets.1.conv2.bias",
+ "encoder.down.2.block.1.conv2.weight": "encoder.down_blocks.2.resnets.1.conv2.weight",
+ "encoder.down.2.block.1.norm1.bias": "encoder.down_blocks.2.resnets.1.norm1.bias",
+ "encoder.down.2.block.1.norm1.weight": "encoder.down_blocks.2.resnets.1.norm1.weight",
+ "encoder.down.2.block.1.norm2.bias": "encoder.down_blocks.2.resnets.1.norm2.bias",
+ "encoder.down.2.block.1.norm2.weight": "encoder.down_blocks.2.resnets.1.norm2.weight",
+ "encoder.down.2.downsample.conv.bias": "encoder.down_blocks.2.downsamplers.0.conv.bias",
+ "encoder.down.2.downsample.conv.weight": "encoder.down_blocks.2.downsamplers.0.conv.weight",
+ "encoder.down.3.block.0.conv1.bias": "encoder.down_blocks.3.resnets.0.conv1.bias",
+ "encoder.down.3.block.0.conv1.weight": "encoder.down_blocks.3.resnets.0.conv1.weight",
+ "encoder.down.3.block.0.conv2.bias": "encoder.down_blocks.3.resnets.0.conv2.bias",
+ "encoder.down.3.block.0.conv2.weight": "encoder.down_blocks.3.resnets.0.conv2.weight",
+ "encoder.down.3.block.0.norm1.bias": "encoder.down_blocks.3.resnets.0.norm1.bias",
+ "encoder.down.3.block.0.norm1.weight": "encoder.down_blocks.3.resnets.0.norm1.weight",
+ "encoder.down.3.block.0.norm2.bias": "encoder.down_blocks.3.resnets.0.norm2.bias",
+ "encoder.down.3.block.0.norm2.weight": "encoder.down_blocks.3.resnets.0.norm2.weight",
+ "encoder.down.3.block.1.conv1.bias": "encoder.down_blocks.3.resnets.1.conv1.bias",
+ "encoder.down.3.block.1.conv1.weight": "encoder.down_blocks.3.resnets.1.conv1.weight",
+ "encoder.down.3.block.1.conv2.bias": "encoder.down_blocks.3.resnets.1.conv2.bias",
+ "encoder.down.3.block.1.conv2.weight": "encoder.down_blocks.3.resnets.1.conv2.weight",
+ "encoder.down.3.block.1.norm1.bias": "encoder.down_blocks.3.resnets.1.norm1.bias",
+ "encoder.down.3.block.1.norm1.weight": "encoder.down_blocks.3.resnets.1.norm1.weight",
+ "encoder.down.3.block.1.norm2.bias": "encoder.down_blocks.3.resnets.1.norm2.bias",
+ "encoder.down.3.block.1.norm2.weight": "encoder.down_blocks.3.resnets.1.norm2.weight",
+ "encoder.mid.attn_1.k.bias": "encoder.mid_block.attentions.0.to_k.bias",
+ "encoder.mid.attn_1.k.weight": "encoder.mid_block.attentions.0.to_k.weight",
+ "encoder.mid.attn_1.norm.bias": "encoder.mid_block.attentions.0.group_norm.bias",
+ "encoder.mid.attn_1.norm.weight": "encoder.mid_block.attentions.0.group_norm.weight",
+ "encoder.mid.attn_1.proj_out.bias": "encoder.mid_block.attentions.0.to_out.0.bias",
+ "encoder.mid.attn_1.proj_out.weight": "encoder.mid_block.attentions.0.to_out.0.weight",
+ "encoder.mid.attn_1.q.bias": "encoder.mid_block.attentions.0.to_q.bias",
+ "encoder.mid.attn_1.q.weight": "encoder.mid_block.attentions.0.to_q.weight",
+ "encoder.mid.attn_1.v.bias": "encoder.mid_block.attentions.0.to_v.bias",
+ "encoder.mid.attn_1.v.weight": "encoder.mid_block.attentions.0.to_v.weight",
+ "encoder.mid.block_1.conv1.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.mid.block_1.conv1.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.mid.block_1.conv2.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.mid.block_1.conv2.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.mid.block_1.norm1.bias": "encoder.mid_block.resnets.0.norm1.bias",
+ "encoder.mid.block_1.norm1.weight": "encoder.mid_block.resnets.0.norm1.weight",
+ "encoder.mid.block_1.norm2.bias": "encoder.mid_block.resnets.0.norm2.bias",
+ "encoder.mid.block_1.norm2.weight": "encoder.mid_block.resnets.0.norm2.weight",
+ "encoder.mid.block_2.conv1.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.mid.block_2.conv1.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.mid.block_2.conv2.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.mid.block_2.conv2.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ "encoder.mid.block_2.norm1.bias": "encoder.mid_block.resnets.1.norm1.bias",
+ "encoder.mid.block_2.norm1.weight": "encoder.mid_block.resnets.1.norm1.weight",
+ "encoder.mid.block_2.norm2.bias": "encoder.mid_block.resnets.1.norm2.bias",
+ "encoder.mid.block_2.norm2.weight": "encoder.mid_block.resnets.1.norm2.weight",
+ "encoder.norm_out.bias": "encoder.conv_norm_out.bias",
+ "encoder.norm_out.weight": "encoder.conv_norm_out.weight",
+ "post_quant_conv.bias": "post_quant_conv.bias",
+ "post_quant_conv.weight": "post_quant_conv.weight",
+ "quant_conv.bias": "quant_conv.bias",
+ "quant_conv.weight": "quant_conv.weight"
+}
+
+
+def get_diffusers_vae_key_from_ldm_key(target_ldm_key, i=None):
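+ # Looks up the Diffusers key for a given LDM VAE key; when i is provided, any "{i}" placeholders
+ # in the mapping table are expanded with that index before comparing.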
+ for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
+ if i is not None:
+ ldm_key = ldm_key.replace("{i}", str(i))
+ diffusers_key = diffusers_key.replace("{i}", str(i))
+ if ldm_key == target_ldm_key:
+ return diffusers_key
+
+ if target_ldm_key in vae_ldm_to_diffusers_dict:
+ return vae_ldm_to_diffusers_dict[target_ldm_key]
+ else:
+ return None
+
+# def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
+# for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
+# if diffusers_key == target_diffusers_key:
+# return ldm_key
+# return None
+
+def get_ldm_vae_key_from_diffusers_key(target_diffusers_key):
+ for ldm_key, diffusers_key in vae_ldm_to_diffusers_dict.items():
+ if "{" in diffusers_key: # if we have a placeholder
+ # escape special characters in the key, and replace the placeholder with a regex group
+ pattern = re.escape(diffusers_key).replace("\\{i\\}", "(\\d+)")
+ match = re.match(pattern, target_diffusers_key)
+ if match: # if we found a match
+ return ldm_key.format(i=match.group(1))
+ elif diffusers_key == target_diffusers_key:
+ return ldm_key
+ return None
+
+
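+# These attention projections are plain Linear weights (2-D) in Diffusers but 1x1 convs (4-D)
+# in the LDM layout, so two trailing singleton dims are added when converting back.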
+vae_keys_squished_on_diffusers = [
+ "decoder.mid_block.attentions.0.to_k.weight",
+ "decoder.mid_block.attentions.0.to_out.0.weight",
+ "decoder.mid_block.attentions.0.to_q.weight",
+ "decoder.mid_block.attentions.0.to_v.weight",
+ "encoder.mid_block.attentions.0.to_k.weight",
+ "encoder.mid_block.attentions.0.to_out.0.weight",
+ "encoder.mid_block.attentions.0.to_q.weight",
+ "encoder.mid_block.attentions.0.to_v.weight"
+]
+
+def convert_diffusers_back_to_ldm(diffusers_vae):
+ new_state_dict = OrderedDict()
+ diffusers_state_dict = diffusers_vae.state_dict()
+ for key, value in diffusers_state_dict.items():
+ val_to_save = value
+ if key in vae_keys_squished_on_diffusers:
+ val_to_save = value.clone()
+ # (512, 512) diffusers and (512, 512, 1, 1) ldm
+ val_to_save = val_to_save.unsqueeze(-1).unsqueeze(-1)
+ ldm_key = get_ldm_vae_key_from_diffusers_key(key)
+ if ldm_key is not None:
+ new_state_dict[ldm_key] = val_to_save
+ else:
+ # for now add current key
+ new_state_dict[key] = val_to_save
+ return new_state_dict
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ mapping = {}
+ # extract state dict for VAE
+ vae_state_dict = {}
+ vae_key = "first_stage_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+ # if len(vae_state_dict) == 0:
+ # # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt file
+ # vae_state_dict = checkpoint
+
+ new_checkpoint = {}
+
+ # for key in list(vae_state_dict.keys()):
+ # diffusers_key = get_diffusers_vae_key_from_ldm_key(key)
+ # if diffusers_key is not None:
+ # new_checkpoint[diffusers_key] = vae_state_dict[key]
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in
+ range(num_down_blocks)}
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in
+ range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ mapping[f"encoder.down.{i}.downsample.conv.weight"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+ mapping[f"encoder.down.{i}.downsample.conv.bias"] = f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
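+ # Decoder up blocks are numbered in reverse between the two layouts:
+ # Diffusers decoder.up_blocks.{i} corresponds to LDM decoder.up.{num_up_blocks - 1 - i}.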
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [key for key in up_blocks[block_id] if
+ f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ mapping[f"decoder.up.{block_id}.upsample.conv.weight"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+ mapping[f"decoder.up.{block_id}.upsample.conv.bias"] = f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
+
+
+def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
+ """
+ Creates a Diffusers UNet2DConditionModel config from the (hard-coded) LDM UNet hyperparameters.
+ """
+ # unet_params = original_config.model.params.unet_config.params
+
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
+
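+ # Blocks whose down-sampling factor appears in UNET_PARAMS_ATTENTION_RESOLUTIONS get cross-attention;
+ # the factor doubles after every down block except the last and is halved again on the way up.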
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ config = dict(
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
+ in_channels=UNET_PARAMS_IN_CHANNELS,
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
+ down_block_types=tuple(down_block_types),
+ up_block_types=tuple(up_block_types),
+ block_out_channels=tuple(block_out_channels),
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
+ )
+ if v2 and use_linear_projection_in_v2:
+ config["use_linear_projection"] = True
+
+ return config
+
+
+def create_vae_diffusers_config():
+ """
+ Creates a Diffusers AutoencoderKL config from the (hard-coded) LDM VAE hyperparameters.
+ """
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = dict(
+ sample_size=VAE_PARAMS_RESOLUTION,
+ in_channels=VAE_PARAMS_IN_CHANNELS,
+ out_channels=VAE_PARAMS_OUT_CH,
+ down_block_types=tuple(down_block_types),
+ up_block_types=tuple(up_block_types),
+ block_out_channels=tuple(block_out_channels),
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
+ )
+ return config
+
+
+def convert_ldm_clip_checkpoint_v1(checkpoint):
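+ # v1 text encoder: the layout already matches transformers' CLIPTextModel, so it is enough to
+ # strip the "cond_stage_model.transformer." prefix.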
+ keys = list(checkpoint.keys())
+ text_model_dict = {}
+ for key in keys:
+ if key.startswith("cond_stage_model.transformer"):
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+ # support checkpoint without position_ids (invalid checkpoint)
+ if "text_model.embeddings.position_ids" not in text_model_dict:
+ text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
+ return text_model_dict
+
+
+def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+ # the key layout is annoyingly different from v1!
+ def convert_key(key):
+ if not key.startswith("cond_stage_model"):
+ return None
+
+ # common conversion
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+ key = key.replace("cond_stage_model.model.", "text_model.")
+
+ if "resblocks" in key:
+ # resblocks conversion
+ key = key.replace(".resblocks.", ".layers.")
+ if ".ln_" in key:
+ key = key.replace(".ln_", ".layer_norm")
+ elif ".mlp." in key:
+ key = key.replace(".c_fc.", ".fc1.")
+ key = key.replace(".c_proj.", ".fc2.")
+ elif ".attn.out_proj" in key:
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+ elif ".attn.in_proj" in key:
+ key = None # special case (fused in_proj), handled separately below
+ else:
+ raise ValueError(f"unexpected key in SD: {key}")
+ elif ".positional_embedding" in key:
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+ elif ".text_projection" in key:
+ key = None # not used???
+ elif ".logit_scale" in key:
+ key = None # not used???
+ elif ".token_embedding" in key:
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+ elif ".ln_final" in key:
+ key = key.replace(".ln_final", ".final_layer_norm")
+ return key
+
+ keys = list(checkpoint.keys())
+ new_sd = {}
+ for key in keys:
+ # skip resblock 23 (the converted text model only keeps 23 hidden layers)
+ if ".resblocks.23." in key:
+ continue
+ new_key = convert_key(key)
+ if new_key is None:
+ continue
+ new_sd[new_key] = checkpoint[key]
+
+ # convert attention weights
+ for key in keys:
+ if ".resblocks.23." in key:
+ continue
+ if ".resblocks" in key and ".attn.in_proj_" in key:
+ # split the fused in_proj tensor into q, k, v
+ values = torch.chunk(checkpoint[key], 3)
+
+ key_suffix = ".weight" if "weight" in key else ".bias"
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+ key_pfx = key_pfx.replace("_weight", "")
+ key_pfx = key_pfx.replace("_bias", "")
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+ # rename or add position_ids
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
+ # waifu diffusion v1.4
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
+ else:
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
+
+ new_sd["text_model.embeddings.position_ids"] = position_ids
+ return new_sd
+
+
+# endregion
+
+
+# region Diffusers -> StableDiffusion conversion code
+# copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
+
+
+def conv_transformer_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in tf_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
+def convert_unet_state_dict_to_sd(v2, unet_state_dict):
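+ # Inverse of convert_ldm_unet_checkpoint: renames Diffusers UNet keys back to the original
+ # stable-diffusion (LDM) layout using prefix/substring replacement tables.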
+ unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+ ]
+
+ unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+ ]
+
+ unet_conversion_map_layer = []
+ for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+
+ if v2:
+ conv_transformer_to_linear(new_state_dict)
+
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
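+ # Diffusers VAE -> LDM VAE key renaming; mid-block attention weights are additionally reshaped
+ # back to 1x1 conv shape at the end (see reshape_weight_for_sd).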
+ vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+ ]
+
+ for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+ # this part accounts for mid blocks in both the encoder and the decoder
+ for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+ ]
+
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ # print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+
+ return new_state_dict
+
+
+# endregion
+
+# region custom model loading/saving helpers
+
+
+def is_safetensors(path):
+ return os.path.splitext(path)[1].lower() == ".safetensors"
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
+ # support models that store the text encoder keys differently (without the 'text_model' segment)
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
+ ]
+
+ if is_safetensors(ckpt_path):
+ checkpoint = None
+ state_dict = load_file(ckpt_path) # , device) # passing the device here may cause an error
+ else:
+ checkpoint = torch.load(ckpt_path, map_location=device)
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ checkpoint = None
+
+ key_reps = []
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+ for key in state_dict.keys():
+ if key.startswith(rep_from):
+ new_key = rep_to + key[len(rep_from):]
+ key_reps.append((key, new_key))
+
+ for key, new_key in key_reps:
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ return checkpoint, state_dict
+
+
+# TODO the dtype handling looks questionable and should be verified; whether the text_encoder can be created in the requested dtype is also unverified
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None,
+ unet_use_linear_projection_in_v2=False):
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
+
+ # Convert the UNet2DConditionModel model.
+ unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+ unet = UNet2DConditionModel(**unet_config).to(device)
+ info = unet.load_state_dict(converted_unet_checkpoint)
+ print("loading u-net:", info)
+
+ # Convert the VAE model.
+ vae_config = create_vae_diffusers_config()
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+ vae = AutoencoderKL(**vae_config).to(device)
+ info = vae.load_state_dict(converted_vae_checkpoint)
+ print("loading vae:", info)
+
+ # convert text_model
+ if v2:
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+ cfg = CLIPTextConfig(
+ vocab_size=49408,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_hidden_layers=23,
+ num_attention_heads=16,
+ max_position_embeddings=77,
+ hidden_act="gelu",
+ layer_norm_eps=1e-05,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ model_type="clip_text_model",
+ projection_dim=512,
+ torch_dtype="float32",
+ transformers_version="4.25.0.dev0",
+ )
+ text_model = CLIPTextModel._from_config(cfg)
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ else:
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+
+ logging.set_verbosity_error() # don't show annoying warning
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
+ logging.set_verbosity_warning()
+
+ # recent transformers versions no longer include position_ids in the state dict, so drop it if the model does not expect it
+ if "text_model.embeddings.position_ids" not in text_model.state_dict():
+ del converted_text_encoder_checkpoint["text_model.embeddings.position_ids"]
+
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ print("loading text encoder:", info)
+
+ return text_model, vae, unet
+
+
+def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
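+ # Inverse of convert_ldm_clip_checkpoint_v2: renames transformers CLIP keys back to the layout
+ # used inside SD v2 checkpoints, re-fusing q/k/v into a single in_proj tensor.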
+ def convert_key(key):
+ # drop position_ids
+ if ".position_ids" in key:
+ return None
+
+ # common
+ key = key.replace("text_model.encoder.", "transformer.")
+ key = key.replace("text_model.", "")
+ if "layers" in key:
+ # resblocks conversion
+ key = key.replace(".layers.", ".resblocks.")
+ if ".layer_norm" in key:
+ key = key.replace(".layer_norm", ".ln_")
+ elif ".mlp." in key:
+ key = key.replace(".fc1.", ".c_fc.")
+ key = key.replace(".fc2.", ".c_proj.")
+ elif ".self_attn.out_proj" in key:
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+ elif ".self_attn." in key:
+ key = None # special case (q/k/v), handled separately below
+ else:
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
+ elif ".position_embedding" in key:
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+ elif ".token_embedding" in key:
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+ elif "final_layer_norm" in key:
+ key = key.replace("final_layer_norm", "ln_final")
+ return key
+
+ keys = list(checkpoint.keys())
+ new_sd = {}
+ for key in keys:
+ new_key = convert_key(key)
+ if new_key is None:
+ continue
+ new_sd[new_key] = checkpoint[key]
+
+ # convert attention weights
+ for key in keys:
+ if "layers" in key and "q_proj" in key:
+ # concatenate q, k, v back into a single in_proj tensor
+ key_q = key
+ key_k = key.replace("q_proj", "k_proj")
+ key_v = key.replace("q_proj", "v_proj")
+
+ value_q = checkpoint[key_q]
+ value_k = checkpoint[key_k]
+ value_v = checkpoint[key_v]
+ value = torch.cat([value_q, value_k, value_v])
+
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+ new_sd[new_key] = value
+
+ # optionally fabricate the last layer and other missing weights
+ if make_dummy_weights:
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
+ keys = list(new_sd.keys())
+ for key in keys:
+ if key.startswith("transformer.resblocks.22."):
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
+
+ # create weights that do not exist in the Diffusers model
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
+ new_sd["logit_scale"] = torch.tensor(1)
+
+ return new_sd
+
+
+def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None,
+ vae=None):
+ if ckpt_path is not None:
+ # reuse epoch/step from the base checkpoint; also reload it (including the VAE) e.g. when the VAE is not in memory
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+ if checkpoint is None: # a safetensors file, or a ckpt that is just a state_dict
+ checkpoint = {}
+ strict = False
+ else:
+ strict = True
+ if "state_dict" in state_dict:
+ del state_dict["state_dict"]
+ else:
+ # build a new checkpoint from scratch
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
+ checkpoint = {}
+ state_dict = {}
+ strict = False
+
+ def update_sd(prefix, sd):
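+ # copy converted weights into state_dict under the given LDM prefix, optionally casting to save_dtype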
+ for k, v in sd.items():
+ key = prefix + k
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
+ if save_dtype is not None:
+ v = v.detach().clone().to("cpu").to(save_dtype)
+ state_dict[key] = v
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
+ update_sd("model.diffusion_model.", unet_state_dict)
+
+ # Convert the text encoder model
+ if v2:
+ make_dummy = ckpt_path is None # with no reference checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
+ update_sd("cond_stage_model.model.", text_enc_dict)
+ else:
+ text_enc_dict = text_encoder.state_dict()
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
+
+ # Convert the VAE
+ if vae is not None:
+ vae_dict = convert_vae_state_dict(vae.state_dict())
+ update_sd("first_stage_model.", vae_dict)
+
+ # Put together new checkpoint
+ key_count = len(state_dict.keys())
+ new_ckpt = {"state_dict": state_dict}
+
+ # epoch and global_step are sometimes not int
+ try:
+ if "epoch" in checkpoint:
+ epochs += checkpoint["epoch"]
+ if "global_step" in checkpoint:
+ steps += checkpoint["global_step"]
+ except:
+ pass
+
+ new_ckpt["epoch"] = epochs
+ new_ckpt["global_step"] = steps
+
+ if is_safetensors(output_file):
+ # TODO should non-tensor values be removed from the dict?
+ save_file(state_dict, output_file)
+ else:
+ torch.save(new_ckpt, output_file)
+
+ return key_count
+
+
+def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None,
+ use_safetensors=False):
+ if pretrained_model_name_or_path is None:
+ # load default settings for v1/v2
+ if v2:
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
+ else:
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
+
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+ if vae is None:
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+ pipeline = StableDiffusionPipeline(
+ unet=unet,
+ text_encoder=text_encoder,
+ vae=vae,
+ scheduler=scheduler,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=None,
+ )
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
+
+
+VAE_PREFIX = "first_stage_model."
+
+
+def load_vae(vae_id, dtype):
+ print(f"load VAE: {vae_id}")
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+ # Diffusers local/remote
+ try:
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+ except EnvironmentError as e:
+ print(f"exception occurs in loading vae: {e}")
+ print("retry with subfolder='vae'")
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+ return vae
+
+ # local
+ vae_config = create_vae_diffusers_config()
+
+ if vae_id.endswith(".bin"):
+ # SD 1.5 VAE on Huggingface
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
+ else:
+ # StableDiffusion
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
+
+ # vae only or full model
+ full_model = False
+ for vae_key in vae_sd:
+ if vae_key.startswith(VAE_PREFIX):
+ full_model = True
+ break
+ if not full_model:
+ sd = {}
+ for key, value in vae_sd.items():
+ sd[VAE_PREFIX + key] = value
+ vae_sd = sd
+ del sd
+
+ # Convert the VAE model.
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_checkpoint)
+ return vae
+
+
+# endregion
+
+
+def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+ max_width, max_height = max_reso
+ max_area = (max_width // divisible) * (max_height // divisible)
+
+ resos = set()
+
+ size = int(math.sqrt(max_area)) * divisible
+ resos.add((size, size))
+
+ size = min_size
+ while size <= max_size:
+ width = size
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
+ resos.add((width, height))
+ resos.add((height, width))
+
+ # # make additional resos
+ # if width >= height and width - divisible >= min_size:
+ # resos.add((width - divisible, height))
+ # resos.add((height, width - divisible))
+ # if height >= width and height - divisible >= min_size:
+ # resos.add((width, height - divisible))
+ # resos.add((height - divisible, width))
+
+ size += divisible
+
+ resos = list(resos)
+ resos.sort()
+ return resos
+
+
+if __name__ == "__main__":
+ resos = make_bucket_resolutions((512, 768))
+ print(len(resos))
+ print(resos)
+ aspect_ratios = [w / h for w, h in resos]
+ print(aspect_ratios)
+
+ ars = set()
+ for ar in aspect_ratios:
+ if ar in ars:
+ print("error! duplicate ar:", ar)
+ ars.add(ar)
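
For context on how the bucket list above is typically consumed, here is a minimal sketch (not part of the patch) that picks the bucket whose aspect ratio is closest to an image's; the pick_bucket helper and the 1200x900 size are illustrative only.

def pick_bucket(resos, image_width, image_height):
    # choose the (width, height) pair whose aspect ratio is closest to the image's
    image_ar = image_width / image_height
    return min(resos, key=lambda wh: abs(wh[0] / wh[1] - image_ar))

resos = make_bucket_resolutions((512, 768))
print(pick_bucket(resos, 1200, 900))  # e.g. a 4:3 photo lands in a wide-ish bucket
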
diff --git a/toolkit/layers.py b/toolkit/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfc975bfb76ee564021c7ea823a3cdba09aeba48
--- /dev/null
+++ b/toolkit/layers.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.utils.checkpoint import checkpoint
+
+
+class ReductionKernel(nn.Module):
+    # ported from a TensorFlow implementation; averages each kernel_size x kernel_size block per channel
+ def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None):
+ if device is None:
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ super(ReductionKernel, self).__init__()
+ self.kernel_size = kernel_size
+ self.in_channels = in_channels
+ numpy_kernel = self.build_kernel()
+ self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+
+ def build_kernel(self):
+ # tensorflow kernel is (height, width, in_channels, out_channels)
+ # pytorch kernel is (out_channels, in_channels, height, width)
+ kernel_size = self.kernel_size
+ channels = self.in_channels
+ kernel_shape = [channels, channels, kernel_size, kernel_size]
+ kernel = np.zeros(kernel_shape, np.float32)
+
+ kernel_value = 1.0 / (kernel_size * kernel_size)
+ for i in range(0, channels):
+ kernel[i, i, :, :] = kernel_value
+ return kernel
+
+ def forward(self, x):
+ return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
+
+
+class CheckpointGradients(nn.Module):
+ def __init__(self, is_gradient_checkpointing=True):
+ super(CheckpointGradients, self).__init__()
+ self.is_gradient_checkpointing = is_gradient_checkpointing
+
+ def forward(self, module, *args, num_chunks=1):
+ if self.is_gradient_checkpointing:
+            # note: torch.utils.checkpoint.checkpoint does not chunk its input, so num_chunks is not forwarded (the original self.num_chunks was undefined)
+            return checkpoint(module, *args)
+ else:
+ return module(*args)
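
A quick usage sketch (not part of the patch) for ReductionKernel: with kernel_size=2 it averages each 2x2 block per channel, so on CPU it should agree with F.avg_pool2d; the tensor shape below is arbitrary.

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8)
reduce = ReductionKernel(in_channels=3, kernel_size=2, device=torch.device("cpu"))
y = reduce(x)
print(list(y.shape))                          # [1, 3, 4, 4]
print(torch.allclose(y, F.avg_pool2d(x, 2)))  # expected: True
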
diff --git a/toolkit/llvae.py b/toolkit/llvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d559bfea01676ad9a9c255d930d693099b1a2c9
--- /dev/null
+++ b/toolkit/llvae.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import itertools
+
+
+class LosslessLatentDecoder(nn.Module):
+ def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
+ super(LosslessLatentDecoder, self).__init__()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.latent_depth = latent_depth
+ self.in_channels = in_channels
+ self.out_channels = int(in_channels // (latent_depth * latent_depth))
+ numpy_kernel = self.build_kernel(in_channels, latent_depth)
+ numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+ if trainable:
+ self.kernel = nn.Parameter(numpy_kernel)
+ else:
+ self.kernel = numpy_kernel
+
+ def build_kernel(self, in_channels, latent_depth):
+ # my old code from tensorflow.
+ # tensorflow kernel is (height, width, out_channels, in_channels)
+ # pytorch kernel is (in_channels, out_channels, height, width)
+ out_channels = self.out_channels
+
+ # kernel_shape = [kernel_filter_size, kernel_filter_size, out_channels, in_channels] # tensorflow
+ kernel_shape = [in_channels, out_channels, latent_depth, latent_depth] # pytorch
+ kernel = np.zeros(kernel_shape, np.float32)
+
+ # Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel.
+ for c in range(0, out_channels):
+ i = 0
+ for x, y in itertools.product(range(latent_depth), repeat=2):
+ # kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow
+ kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0 # pytorch
+ i += 1
+
+ return kernel
+
+ def forward(self, x):
+ dtype = x.dtype
+ if self.kernel.dtype != dtype:
+ self.kernel = self.kernel.to(dtype=dtype)
+
+ # Deconvolve input tensor with the kernel
+ return nn.functional.conv_transpose2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
+
+
+class LosslessLatentEncoder(nn.Module):
+ def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
+ super(LosslessLatentEncoder, self).__init__()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.latent_depth = latent_depth
+ self.in_channels = in_channels
+ self.out_channels = int(in_channels * (latent_depth * latent_depth))
+ numpy_kernel = self.build_kernel(in_channels, latent_depth)
+ numpy_kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
+ if trainable:
+ self.kernel = nn.Parameter(numpy_kernel)
+ else:
+ self.kernel = numpy_kernel
+
+
+ def build_kernel(self, in_channels, latent_depth):
+ # my old code from tensorflow.
+ # tensorflow kernel is (height, width, in_channels, out_channels)
+ # pytorch kernel is (out_channels, in_channels, height, width)
+ out_channels = self.out_channels
+
+ # kernel_shape = [latent_depth, latent_depth, in_channels, out_channels] # tensorflow
+ kernel_shape = [out_channels, in_channels, latent_depth, latent_depth] # pytorch
+ kernel = np.zeros(kernel_shape, np.float32)
+
+ # Build the kernel so that a 4 pixel cluster has each pixel come from a separate channel.
+ for c in range(0, in_channels):
+ i = 0
+ for x, y in itertools.product(range(latent_depth), repeat=2):
+ # kernel[y, x, c, c * latent_depth * latent_depth + i] = 1 # tensorflow
+ kernel[c * latent_depth * latent_depth + i, c, y, x] = 1.0 # pytorch
+ i += 1
+ return kernel
+
+ def forward(self, x):
+ dtype = x.dtype
+ if self.kernel.dtype != dtype:
+ self.kernel = self.kernel.to(dtype=dtype)
+ # Convolve input tensor with the kernel
+ return nn.functional.conv2d(x, self.kernel, stride=self.latent_depth, padding=0, groups=1)
+
+
+class LosslessLatentVAE(nn.Module):
+ def __init__(self, in_channels, latent_depth, dtype=torch.float32, trainable=False):
+ super(LosslessLatentVAE, self).__init__()
+ self.latent_depth = latent_depth
+ self.in_channels = in_channels
+ self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype, trainable=trainable)
+ encoder_out_channels = self.encoder.out_channels
+ self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype, trainable=trainable)
+
+ def forward(self, x):
+        latent = self.encoder(x)
+        out = self.decoder(latent)
+ return out
+
+ def encode(self, x):
+ return self.encoder(x)
+
+ def decode(self, x):
+ return self.decoder(x)
+
+
+# test it
+if __name__ == '__main__':
+ import os
+ from PIL import Image
+ import torchvision.transforms as transforms
+ user_path = os.path.expanduser('~')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.float32
+
+ input_path = os.path.join(user_path, "Pictures/sample_2_512.png")
+ output_path = os.path.join(user_path, "Pictures/sample_2_512_llvae.png")
+ img = Image.open(input_path)
+ img_tensor = transforms.ToTensor()(img)
+ img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
+ print("input_shape: ", list(img_tensor.shape))
+ vae = LosslessLatentVAE(in_channels=3, latent_depth=8, dtype=dtype).to(device=device, dtype=dtype)
+ latent = vae.encode(img_tensor)
+ print("latent_shape: ", list(latent.shape))
+ out_tensor = vae.decode(latent)
+ print("out_shape: ", list(out_tensor.shape))
+
+ mse_loss = nn.MSELoss()
+ mse = mse_loss(img_tensor, out_tensor)
+ print("roundtrip_loss: ", mse.item())
+ out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
+ out_img.save(output_path)
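
A small round-trip check (not part of the patch): the encoder is a pure space-to-depth rearrangement, so encode followed by decode should reproduce the input exactly; the sizes are arbitrary, and the input is created on the same device the modules pick internally.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(2, 3, 64, 64, device=device)
vae = LosslessLatentVAE(in_channels=3, latent_depth=8)
z = vae.encode(x)
print(list(z.shape))                     # [2, 192, 8, 8] -> 3 * 8 * 8 channels
print(torch.allclose(x, vae.decode(z)))  # expected: True
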
diff --git a/toolkit/logging.py b/toolkit/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b1c8b52a301c50fae9d016e74f9aa65c974882
--- /dev/null
+++ b/toolkit/logging.py
@@ -0,0 +1,84 @@
+from typing import OrderedDict, Optional
+from PIL import Image
+
+from toolkit.config_modules import LoggingConfig
+
+# Base logger class
+# This class does nothing, it's just a placeholder
+class EmptyLogger:
+ def __init__(self, *args, **kwargs) -> None:
+ pass
+
+ # start logging the training
+ def start(self):
+ pass
+
+ # collect the log to send
+ def log(self, *args, **kwargs):
+ pass
+
+ # send the log
+ def commit(self, step: Optional[int] = None):
+ pass
+
+ # log image
+ def log_image(self, *args, **kwargs):
+ pass
+
+ # finish logging
+ def finish(self):
+ pass
+
+# Wandb logger class
+# This class logs the data to wandb
+class WandbLogger(EmptyLogger):
+ def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
+ self.project = project
+ self.run_name = run_name
+ self.config = config
+
+ def start(self):
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError("Failed to import wandb. Please install wandb by running `pip install wandb`")
+
+ # send the whole config to wandb
+ run = wandb.init(project=self.project, name=self.run_name, config=self.config)
+ self.run = run
+ self._log = wandb.log # log function
+ self._image = wandb.Image # image object
+
+ def log(self, *args, **kwargs):
+ # when commit is False, wandb increments the step,
+ # but we don't want that to happen, so we set commit=False
+ self._log(*args, **kwargs, commit=False)
+
+ def commit(self, step: Optional[int] = None):
+ # after overall one step is done, we commit the log
+ # by log empty object with commit=True
+ self._log({}, step=step, commit=True)
+
+ def log_image(
+ self,
+ image: Image,
+ id, # sample index
+ caption: str | None = None, # positive prompt
+ *args,
+ **kwargs,
+ ):
+ # create a wandb image object and log it
+ image = self._image(image, caption=caption, *args, **kwargs)
+ self._log({f"sample_{id}": image}, commit=False)
+
+ def finish(self):
+ self.run.finish()
+
+# create logger based on the logging config
+def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
+ if logging_config.use_wandb:
+ project_name = logging_config.project_name
+ run_name = logging_config.run_name
+ return WandbLogger(project=project_name, run_name=run_name, config=all_config)
+ else:
+ return EmptyLogger()
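
A minimal sketch (not part of the patch) of the intended call sequence. SimpleNamespace stands in for a LoggingConfig here; only the attributes create_logger actually reads (use_wandb, project_name, run_name) are assumed, and with use_wandb=False every call is a no-op on the EmptyLogger.

from types import SimpleNamespace
from collections import OrderedDict

logging_config = SimpleNamespace(use_wandb=False, project_name="demo", run_name=None)
logger = create_logger(logging_config, OrderedDict())

logger.start()
for step in range(3):
    logger.log({"loss": 1.0 / (step + 1)})
    logger.commit(step=step)
logger.finish()
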
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c53439a5b6a0c47a74f32876c76ee92cc63bbaa
--- /dev/null
+++ b/toolkit/lora_special.py
@@ -0,0 +1,505 @@
+import copy
+import json
+import math
+import weakref
+import os
+import re
+import sys
+from typing import List, Optional, Dict, Type, Union
+import torch
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
+from transformers import CLIPTextModel
+
+from .config_modules import NetworkConfig
+from .lorm import count_parameters
+from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
+from .paths import SD_SCRIPTS_ROOT
+
+sys.path.append(SD_SCRIPTS_ROOT)
+
+from networks.lora import LoRANetwork, get_block_index
+from toolkit.models.DoRA import DoRAModule
+
+from torch.utils.checkpoint import checkpoint
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+# diffusers specific stuff
+LINEAR_MODULES = [
+ 'Linear',
+ 'LoRACompatibleLinear',
+ 'QLinear',
+ # 'GroupNorm',
+]
+CONV_MODULES = [
+ 'Conv2d',
+ 'LoRACompatibleConv',
+ 'QConv2d',
+]
+
+class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ network: 'LoRASpecialNetwork' = None,
+ use_bias: bool = False,
+ **kwargs
+ ):
+ self.can_merge_in = True
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
+ ToolkitModuleMixin.__init__(self, network=network)
+ torch.nn.Module.__init__(self)
+ self.lora_name = lora_name
+ self.orig_module_ref = weakref.ref(org_module)
+ self.scalar = torch.tensor(1.0)
+ # check if parent has bias. if not force use_bias to False
+ if org_module.bias is None:
+ use_bias = False
+
+ if org_module.__class__.__name__ in CONV_MODULES:
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ # if limit_rank:
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
+ # if self.lora_dim != lora_dim:
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
+ # else:
+ self.lora_dim = lora_dim
+
+ if org_module.__class__.__name__ in CONV_MODULES:
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
+ else:
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
+ # same as microsoft's
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.zeros_(self.lora_up.weight)
+
+ self.multiplier: Union[float, List[float]] = multiplier
+ # wrap the original module so it doesn't get weights updated
+ self.org_module = [org_module]
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.is_checkpointing = False
+
+ def apply_to(self):
+ self.org_forward = self.org_module[0].forward
+ self.org_module[0].forward = self.forward
+ # del self.org_module
+
+
+class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
+    NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
+
+ # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
+ # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"]
+ UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"]
+ # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+ LORA_PREFIX_UNET = "lora_unet"
+ PEFT_PREFIX_UNET = "unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+
+ # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+
+ def __init__(
+ self,
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+ unet,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ conv_lora_dim: Optional[int] = None,
+ conv_alpha: Optional[float] = None,
+ block_dims: Optional[List[int]] = None,
+ block_alphas: Optional[List[float]] = None,
+ conv_block_dims: Optional[List[int]] = None,
+ conv_block_alphas: Optional[List[float]] = None,
+ modules_dim: Optional[Dict[str, int]] = None,
+ modules_alpha: Optional[Dict[str, int]] = None,
+ module_class: Type[object] = LoRAModule,
+ varbose: Optional[bool] = False,
+ train_text_encoder: Optional[bool] = True,
+ use_text_encoder_1: bool = True,
+ use_text_encoder_2: bool = True,
+ train_unet: Optional[bool] = True,
+ is_sdxl=False,
+ is_v2=False,
+ is_v3=False,
+ is_pixart: bool = False,
+ is_auraflow: bool = False,
+ is_flux: bool = False,
+ use_bias: bool = False,
+ is_lorm: bool = False,
+ ignore_if_contains = None,
+ only_if_contains = None,
+ parameter_threshold: float = 0.0,
+ attn_only: bool = False,
+ target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
+ target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
+ network_type: str = "lora",
+ full_train_in_out: bool = False,
+ transformer_only: bool = False,
+ peft_format: bool = False,
+ is_assistant_adapter: bool = False,
+ **kwargs
+ ) -> None:
+ """
+        LoRA network: there are a lot of arguments, but the supported patterns are:
+        1. specify lora_dim and alpha
+        2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
+        3. specify block_dims and block_alphas : not applied to Conv2d 3x3
+        4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas : also applied to Conv2d 3x3
+        5. specify modules_dim and modules_alpha (for inference)
+ """
+ # call the parent of the parent we are replacing (LoRANetwork) init
+ torch.nn.Module.__init__(self)
+ ToolkitNetworkMixin.__init__(
+ self,
+ train_text_encoder=train_text_encoder,
+ train_unet=train_unet,
+ is_sdxl=is_sdxl,
+ is_v2=is_v2,
+ is_lorm=is_lorm,
+ **kwargs
+ )
+ if ignore_if_contains is None:
+ ignore_if_contains = []
+ self.ignore_if_contains = ignore_if_contains
+ self.transformer_only = transformer_only
+
+ self.only_if_contains: Union[List, None] = only_if_contains
+
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.conv_lora_dim = conv_lora_dim
+ self.conv_alpha = conv_alpha
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.is_checkpointing = False
+ self._multiplier: float = 1.0
+ self.is_active: bool = False
+ self.torch_multiplier = None
+ # triggers the state updates
+ self.multiplier = multiplier
+ self.is_sdxl = is_sdxl
+ self.is_v2 = is_v2
+ self.is_v3 = is_v3
+ self.is_pixart = is_pixart
+ self.is_auraflow = is_auraflow
+ self.is_flux = is_flux
+ self.network_type = network_type
+ self.is_assistant_adapter = is_assistant_adapter
+ if self.network_type.lower() == "dora":
+ self.module_class = DoRAModule
+ module_class = DoRAModule
+
+ self.peft_format = peft_format
+
+ # always do peft for flux only for now
+ if self.is_flux or self.is_v3:
+ self.peft_format = True
+
+ if self.peft_format:
+ # no alpha for peft
+ self.alpha = self.lora_dim
+ alpha = self.alpha
+ self.conv_alpha = self.conv_lora_dim
+ conv_alpha = self.conv_alpha
+
+ self.full_train_in_out = full_train_in_out
+
+ if modules_dim is not None:
+ print(f"create LoRA network from weights")
+ elif block_dims is not None:
+ print(f"create LoRA network from block_dims")
+ print(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+ print(f"block_dims: {block_dims}")
+ print(f"block_alphas: {block_alphas}")
+ if conv_block_dims is not None:
+ print(f"conv_block_dims: {conv_block_dims}")
+ print(f"conv_block_alphas: {conv_block_alphas}")
+ else:
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ print(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+ if self.conv_lora_dim is not None:
+ print(
+ f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+
+ # create module instances
+ def create_modules(
+ is_unet: bool,
+ text_encoder_idx: Optional[int], # None, 1, 2
+ root_module: torch.nn.Module,
+ target_replace_modules: List[torch.nn.Module],
+ ) -> List[LoRAModule]:
+ unet_prefix = self.LORA_PREFIX_UNET
+ if self.peft_format:
+ unet_prefix = self.PEFT_PREFIX_UNET
+ if is_pixart or is_v3 or is_auraflow or is_flux:
+ unet_prefix = f"lora_transformer"
+ if self.peft_format:
+ unet_prefix = "transformer"
+
+ prefix = (
+ unet_prefix
+ if is_unet
+ else (
+ self.LORA_PREFIX_TEXT_ENCODER
+ if text_encoder_idx is None
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
+ )
+ )
+ loras = []
+ skipped = []
+ attached_modules = []
+ lora_shape_dict = {}
+ for name, module in root_module.named_modules():
+ if module.__class__.__name__ in target_replace_modules:
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ in LINEAR_MODULES
+ is_conv2d = child_module.__class__.__name__ in CONV_MODULES
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+
+ lora_name = [prefix, name, child_name]
+ # filter out blank
+ lora_name = [x for x in lora_name if x and x != ""]
+ lora_name = ".".join(lora_name)
+                        # if a segment is blank, the joined name could contain a double dot; collapse it
+                        lora_name = lora_name.replace("..", ".")
+ clean_name = lora_name
+ if self.peft_format:
+ # we replace this on saving
+ lora_name = lora_name.replace(".", "$$")
+ else:
+ lora_name = lora_name.replace(".", "_")
+
+ skip = False
+ if any([word in clean_name for word in self.ignore_if_contains]):
+ skip = True
+
+ # see if it is over threshold
+ if count_parameters(child_module) < parameter_threshold:
+ skip = True
+
+ if self.transformer_only and self.is_pixart and is_unet:
+ if "transformer_blocks" not in lora_name:
+ skip = True
+ if self.transformer_only and self.is_flux and is_unet:
+ if "transformer_blocks" not in lora_name:
+ skip = True
+ if self.transformer_only and self.is_v3 and is_unet:
+ if "transformer_blocks" not in lora_name:
+ skip = True
+
+ if (is_linear or is_conv2d) and not skip:
+
+ if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
+ continue
+
+ dim = None
+ alpha = None
+
+ if modules_dim is not None:
+                                # module-specific dims were provided
+ if lora_name in modules_dim:
+ dim = modules_dim[lora_name]
+ alpha = modules_alpha[lora_name]
+ elif is_unet and block_dims is not None:
+                                # block_dims specified for the U-Net
+ block_idx = get_block_index(lora_name)
+ if is_linear or is_conv2d_1x1:
+ dim = block_dims[block_idx]
+ alpha = block_alphas[block_idx]
+ elif conv_block_dims is not None:
+ dim = conv_block_dims[block_idx]
+ alpha = conv_block_alphas[block_idx]
+ else:
+                                # default: apply to every targeted module
+ if is_linear or is_conv2d_1x1:
+ dim = self.lora_dim
+ alpha = self.alpha
+ elif self.conv_lora_dim is not None:
+ dim = self.conv_lora_dim
+ alpha = self.conv_alpha
+
+ if dim is None or dim == 0:
+                                # record modules that were skipped so they can be reported
+ if is_linear or is_conv2d_1x1 or (
+ self.conv_lora_dim is not None or conv_block_dims is not None):
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha,
+ dropout=dropout,
+ rank_dropout=rank_dropout,
+ module_dropout=module_dropout,
+ network=self,
+ parent=module,
+ use_bias=use_bias,
+ )
+ loras.append(lora)
+                            lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
+ return loras, skipped
+
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+ # create LoRA for text encoder
+        # creating LoRA modules for every text encoder each time is wasteful; worth revisiting
+ self.text_encoder_loras = []
+ skipped_te = []
+ if train_text_encoder:
+ for i, text_encoder in enumerate(text_encoders):
+ if not use_text_encoder_1 and i == 0:
+ continue
+ if not use_text_encoder_2 and i == 1:
+ continue
+ if len(text_encoders) > 1:
+ index = i + 1
+ print(f"create LoRA for Text Encoder {index}:")
+ else:
+ index = None
+ print(f"create LoRA for Text Encoder:")
+
+ replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+
+ if self.is_pixart:
+ replace_modules = ["T5EncoderModel"]
+
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules)
+ self.text_encoder_loras.extend(text_encoder_loras)
+ skipped_te += skipped
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        target_modules = list(target_lin_modules)  # copy so the shared default list is not mutated below
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
+ target_modules += target_conv_modules
+
+ if is_v3:
+ target_modules = ["SD3Transformer2DModel"]
+
+ if is_pixart:
+ target_modules = ["PixArtTransformer2DModel"]
+
+ if is_auraflow:
+ target_modules = ["AuraFlowTransformer2DModel"]
+
+ if is_flux:
+ target_modules = ["FluxTransformer2DModel"]
+
+ if train_unet:
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
+ else:
+ self.unet_loras = []
+ skipped_un = []
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+ skipped = skipped_te + skipped_un
+ if varbose and len(skipped) > 0:
+ print(
+                f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped:"
+ )
+ for name in skipped:
+ print(f"\t{name}")
+
+ self.up_lr_weight: List[float] = None
+ self.down_lr_weight: List[float] = None
+ self.mid_lr_weight: float = None
+ self.block_lr = False
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ if self.full_train_in_out:
+ print("full train in out")
+ # we are going to retrain the main in out layers for VAE change usually
+ if self.is_pixart:
+ transformer: PixArtTransformer2DModel = unet
+ self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
+ self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+
+ transformer.pos_embed = self.transformer_pos_embed
+ transformer.proj_out = self.transformer_proj_out
+
+ elif self.is_auraflow:
+ transformer: AuraFlowTransformer2DModel = unet
+ self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
+ self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+
+ transformer.pos_embed = self.transformer_pos_embed
+ transformer.proj_out = self.transformer_proj_out
+
+ else:
+ unet: UNet2DConditionModel = unet
+ unet_conv_in: torch.nn.Conv2d = unet.conv_in
+ unet_conv_out: torch.nn.Conv2d = unet.conv_out
+
+ # clone these and replace their forwards with ours
+ self.unet_conv_in = copy.deepcopy(unet_conv_in)
+ self.unet_conv_out = copy.deepcopy(unet_conv_out)
+ unet.conv_in = self.unet_conv_in
+ unet.conv_out = self.unet_conv_out
+
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+ # call Lora prepare_optimizer_params
+ all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
+
+ if self.full_train_in_out:
+ if self.is_pixart or self.is_auraflow or self.is_flux:
+ all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
+ all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
+ else:
+ all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())})
+ all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())})
+
+ return all_params
+
+
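
Independent of the network plumbing above, a minimal sketch (not part of the patch) of the low-rank update each LoRAModule contributes at runtime, using the same initialization and alpha/rank scaling as the module class; the layer sizes, rank and input are made up.

import math
import torch

base = torch.nn.Linear(64, 32)
rank, alpha, multiplier = 4, 4.0, 1.0
lora_down = torch.nn.Linear(64, rank, bias=False)
lora_up = torch.nn.Linear(rank, 32, bias=False)
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(lora_up.weight)  # the update starts at zero

x = torch.randn(8, 64)
scale = alpha / rank
y = base(x) + multiplier * scale * lora_up(lora_down(x))
print(list(y.shape))  # [8, 32]
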
diff --git a/toolkit/lorm.py b/toolkit/lorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cfdb516be12d98e464a7d9a96bc4e19b83b9f91
--- /dev/null
+++ b/toolkit/lorm.py
@@ -0,0 +1,461 @@
+from typing import Union, Tuple, Literal, Optional
+
+import torch
+import torch.nn as nn
+from diffusers import UNet2DConditionModel
+from torch import Tensor
+from tqdm import tqdm
+
+from toolkit.config_modules import LoRMConfig
+
+conv = nn.Conv2d
+lin = nn.Linear
+_size_2_t = Union[int, Tuple[int, int]]
+
+ExtractMode = Literal[
+    'fixed',
+    'threshold',
+    'ratio',
+    'quantile',
+    'percentage'
+]
+
+LINEAR_MODULES = [
+ 'Linear',
+ 'LoRACompatibleLinear'
+]
+CONV_MODULES = [
+ # 'Conv2d',
+ # 'LoRACompatibleConv'
+]
+
+UNET_TARGET_REPLACE_MODULE = [
+ "Transformer2DModel",
+ # "ResnetBlock2D",
+ "Downsample2D",
+ "Upsample2D",
+]
+
+LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE
+
+UNET_TARGET_REPLACE_NAME = [
+ "conv_in",
+ "conv_out",
+ "time_embedding.linear_1",
+ "time_embedding.linear_2",
+]
+
+UNET_MODULES_TO_AVOID = [
+]
+
+
+# Low Rank Convolution
+class LoRMCon2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ lorm_channels: int,
+ out_channels: int,
+ kernel_size: _size_2_t,
+ stride: _size_2_t = 1,
+ padding: Union[str, _size_2_t] = 'same',
+ dilation: _size_2_t = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.lorm_channels = lorm_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.padding_mode = padding_mode
+
+ self.down = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=lorm_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=False,
+ padding_mode=padding_mode,
+ device=device,
+ dtype=dtype
+ )
+
+ # Kernel size on the up is always 1x1.
+ # I don't think you could calculate a dual 3x3, or I can't at least
+
+ self.up = nn.Conv2d(
+ in_channels=lorm_channels,
+ out_channels=out_channels,
+ kernel_size=(1, 1),
+ stride=1,
+ padding='same',
+ dilation=1,
+ groups=1,
+ bias=bias,
+ padding_mode='zeros',
+ device=device,
+ dtype=dtype
+ )
+
+ def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
+ x = input
+ x = self.down(x)
+ x = self.up(x)
+ return x
+
+
+class LoRMLinear(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ lorm_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None
+ ) -> None:
+ super().__init__()
+ self.in_features = in_features
+ self.lorm_features = lorm_features
+ self.out_features = out_features
+
+ self.down = nn.Linear(
+ in_features=in_features,
+ out_features=lorm_features,
+ bias=False,
+ device=device,
+ dtype=dtype
+
+ )
+ self.up = nn.Linear(
+ in_features=lorm_features,
+ out_features=out_features,
+ bias=bias,
+ # bias=True,
+ device=device,
+ dtype=dtype
+ )
+
+ def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
+ x = input
+ x = self.down(x)
+ x = self.up(x)
+ return x
+
+
+def extract_conv(
+ weight: Union[torch.Tensor, nn.Parameter],
+ mode='fixed',
+ mode_param=0,
+ device='cpu'
+) -> Tuple[Tensor, Tensor, int, Tensor]:
+ weight = weight.to(device)
+ out_ch, in_ch, kernel_size, _ = weight.shape
+
+ U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
+ if mode == 'percentage':
+ assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
+ original_params = out_ch * in_ch * kernel_size * kernel_size
+ desired_params = mode_param * original_params
+ # Solve for lora_rank from the equation
+ lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
+ elif mode == 'fixed':
+ lora_rank = mode_param
+ elif mode == 'threshold':
+ assert mode_param >= 0
+ lora_rank = torch.sum(S > mode_param).item()
+ elif mode == 'ratio':
+ assert 1 >= mode_param >= 0
+ min_s = torch.max(S) * mode_param
+ lora_rank = torch.sum(S > min_s).item()
+ elif mode == 'quantile' or mode == 'percentile':
+ assert 1 >= mode_param >= 0
+ s_cum = torch.cumsum(S, dim=0)
+ min_cum_sum = mode_param * torch.sum(S)
+ lora_rank = torch.sum(s_cum < min_cum_sum).item()
+ else:
+ raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
+ lora_rank = max(1, lora_rank)
+ lora_rank = min(out_ch, in_ch, lora_rank)
+ if lora_rank >= out_ch / 2:
+ lora_rank = int(out_ch / 2)
+        print("determined rank is higher than expected; clamping to out_ch / 2")
+ # print(f"Skipping layer as determined rank is too high")
+ # return None, None, None, None
+ # return weight, 'full'
+
+ U = U[:, :lora_rank]
+ S = S[:lora_rank]
+ U = U @ torch.diag(S)
+ Vh = Vh[:lora_rank, :]
+
+ diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
+ extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
+ extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
+ del U, S, Vh, weight
+ return extract_weight_A, extract_weight_B, lora_rank, diff
+
+
+def extract_linear(
+ weight: Union[torch.Tensor, nn.Parameter],
+ mode='fixed',
+ mode_param=0,
+ device='cpu',
+) -> Tuple[Tensor, Tensor, int, Tensor]:
+ weight = weight.to(device)
+ out_ch, in_ch = weight.shape
+
+ U, S, Vh = torch.linalg.svd(weight)
+
+ if mode == 'percentage':
+ assert 0 <= mode_param <= 1 # Ensure it's a valid percentage.
+ desired_params = mode_param * out_ch * in_ch
+ # Solve for lora_rank from the equation
+ lora_rank = int(desired_params / (in_ch + out_ch))
+ elif mode == 'fixed':
+ lora_rank = mode_param
+ elif mode == 'threshold':
+ assert mode_param >= 0
+ lora_rank = torch.sum(S > mode_param).item()
+ elif mode == 'ratio':
+ assert 1 >= mode_param >= 0
+ min_s = torch.max(S) * mode_param
+ lora_rank = torch.sum(S > min_s).item()
+ elif mode == 'quantile':
+ assert 1 >= mode_param >= 0
+ s_cum = torch.cumsum(S, dim=0)
+ min_cum_sum = mode_param * torch.sum(S)
+ lora_rank = torch.sum(s_cum < min_cum_sum).item()
+ else:
+ raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
+ lora_rank = max(1, lora_rank)
+ lora_rank = min(out_ch, in_ch, lora_rank)
+ if lora_rank >= out_ch / 2:
+ # print(f"rank is higher than it should be")
+ lora_rank = int(out_ch / 2)
+ # return weight, 'full'
+ # print(f"Skipping layer as determined rank is too high")
+ # return None, None, None, None
+
+ U = U[:, :lora_rank]
+ S = S[:lora_rank]
+ U = U @ torch.diag(S)
+ Vh = Vh[:lora_rank, :]
+
+ diff = (weight - U @ Vh).detach()
+ extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
+ extract_weight_B = U.reshape(out_ch, lora_rank).detach()
+ del U, S, Vh, weight
+ return extract_weight_A, extract_weight_B, lora_rank, diff
+
+
+def replace_module_by_path(network, name, module):
+ """Replace a module in a network by its name."""
+ name_parts = name.split('.')
+ current_module = network
+ for part in name_parts[:-1]:
+ current_module = getattr(current_module, part)
+ try:
+ setattr(current_module, name_parts[-1], module)
+ except Exception as e:
+ print(e)
+
+
+def count_parameters(module):
+ return sum(p.numel() for p in module.parameters())
+
+
+def compute_optimal_bias(original_module, linear_down, linear_up, X):
+ Y_original = original_module(X)
+ Y_approx = linear_up(linear_down(X))
+ E = Y_original - Y_approx
+
+ optimal_bias = E.mean(dim=0)
+
+ return optimal_bias
+
+
+def format_with_commas(n):
+ return f"{n:,}"
+
+
+def print_lorm_extract_details(
+ start_num_params: int,
+ end_num_params: int,
+ num_replaced: int,
+):
+ start_formatted = format_with_commas(start_num_params)
+ end_formatted = format_with_commas(end_num_params)
+ num_replaced_formatted = format_with_commas(num_replaced)
+
+ width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
+
+ print(f"Convert UNet result:")
+ print(f" - converted: {num_replaced:>{width},} modules")
+ print(f" - start: {start_num_params:>{width},} params")
+ print(f" - end: {end_num_params:>{width},} params")
+
+
+lorm_ignore_if_contains = [
+ 'proj_out', 'proj_in',
+]
+
+lorm_parameter_threshold = 1000000
+
+
+@torch.no_grad()
+def convert_diffusers_unet_to_lorm(
+ unet: UNet2DConditionModel,
+ config: LoRMConfig,
+):
+ print('Converting UNet to LoRM UNet')
+ start_num_params = count_parameters(unet)
+ named_modules = list(unet.named_modules())
+
+ num_replaced = 0
+
+ pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
+ layer_names_replaced = []
+ converted_modules = []
+ ignore_if_contains = [
+ 'proj_out', 'proj_in',
+ ]
+
+ for name, module in named_modules:
+ module_name = module.__class__.__name__
+ if module_name in UNET_TARGET_REPLACE_MODULE:
+ for child_name, child_module in module.named_modules():
+ new_module: Union[LoRMCon2d, LoRMLinear, None] = None
+                # fully qualified module name used for config lookup and replacement
+                combined_name = f"{name}.{child_name}"
+ # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
+ # pass
+
+ lorm_config = config.get_config_for_module(combined_name)
+
+ extract_mode = lorm_config.extract_mode
+ extract_mode_param = lorm_config.extract_mode_param
+ parameter_threshold = lorm_config.parameter_threshold
+
+ if any([word in child_name for word in ignore_if_contains]):
+ pass
+
+ elif child_module.__class__.__name__ in LINEAR_MODULES:
+ if count_parameters(child_module) > parameter_threshold:
+
+ # dtype = child_module.weight.dtype
+ dtype = torch.float32
+ # extract and convert
+ down_weight, up_weight, lora_dim, diff = extract_linear(
+ weight=child_module.weight.clone().detach().float(),
+ mode=extract_mode,
+ mode_param=extract_mode_param,
+ device=child_module.weight.device,
+ )
+ if down_weight is None:
+ continue
+ down_weight = down_weight.to(dtype=dtype)
+ up_weight = up_weight.to(dtype=dtype)
+ bias_weight = None
+ if child_module.bias is not None:
+ bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
+ # linear layer weights = (out_features, in_features)
+ new_module = LoRMLinear(
+ in_features=down_weight.shape[1],
+ lorm_features=lora_dim,
+ out_features=up_weight.shape[0],
+ bias=bias_weight is not None,
+ device=down_weight.device,
+ dtype=down_weight.dtype
+ )
+
+ # replace the weights
+ new_module.down.weight.data = down_weight
+ new_module.up.weight.data = up_weight
+ if bias_weight is not None:
+ new_module.up.bias.data = bias_weight
+ # else:
+ # new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)
+
+ # bias_correction = compute_optimal_bias(
+ # child_module,
+ # new_module.down,
+ # new_module.up,
+ # torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
+ # )
+ # new_module.up.bias.data += bias_correction
+
+ elif child_module.__class__.__name__ in CONV_MODULES:
+ if count_parameters(child_module) > parameter_threshold:
+ dtype = child_module.weight.dtype
+ down_weight, up_weight, lora_dim, diff = extract_conv(
+ weight=child_module.weight.clone().detach().float(),
+ mode=extract_mode,
+ mode_param=extract_mode_param,
+ device=child_module.weight.device,
+ )
+ if down_weight is None:
+ continue
+ down_weight = down_weight.to(dtype=dtype)
+ up_weight = up_weight.to(dtype=dtype)
+ bias_weight = None
+ if child_module.bias is not None:
+ bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
+
+ new_module = LoRMCon2d(
+ in_channels=down_weight.shape[1],
+ lorm_channels=lora_dim,
+ out_channels=up_weight.shape[0],
+ kernel_size=child_module.kernel_size,
+ dilation=child_module.dilation,
+ padding=child_module.padding,
+ padding_mode=child_module.padding_mode,
+ stride=child_module.stride,
+ bias=bias_weight is not None,
+ device=down_weight.device,
+ dtype=down_weight.dtype
+ )
+ # replace the weights
+ new_module.down.weight.data = down_weight
+ new_module.up.weight.data = up_weight
+ if bias_weight is not None:
+ new_module.up.bias.data = bias_weight
+
+ if new_module:
+ combined_name = f"{name}.{child_name}"
+ replace_module_by_path(unet, combined_name, new_module)
+ converted_modules.append(new_module)
+ num_replaced += 1
+ layer_names_replaced.append(
+ f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
+
+ pbar.update(1)
+ pbar.close()
+ end_num_params = count_parameters(unet)
+
+ def sorting_key(s):
+ # Extract the number part, remove commas, and convert to integer
+ return int(s.split("-")[1].strip().replace(",", ""))
+
+ sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
+ for layer_name in sorted_layer_names_replaced:
+ print(layer_name)
+
+ print_lorm_extract_details(
+ start_num_params=start_num_params,
+ end_num_params=end_num_params,
+ num_replaced=num_replaced,
+ )
+
+ return converted_modules
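
A per-layer sketch (not part of the patch) of what convert_diffusers_unet_to_lorm does to a single Linear module: factor the weight with extract_linear and load the factors into a LoRMLinear; the layer sizes and the fixed rank are made up.

import torch
import torch.nn as nn

layer = nn.Linear(256, 128)
down_w, up_w, rank, diff = extract_linear(
    layer.weight.clone().detach().float(), mode='fixed', mode_param=16
)
lorm = LoRMLinear(in_features=256, lorm_features=rank, out_features=128, bias=True)
lorm.down.weight.data = down_w
lorm.up.weight.data = up_w
lorm.up.bias.data = layer.bias.data.clone()

x = torch.randn(4, 256)
print((layer(x) - lorm(x)).abs().max())  # residual from the dropped singular values
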
diff --git a/toolkit/losses.py b/toolkit/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeea357111f38b54f6b79ea3e73f23f43ba2dbd7
--- /dev/null
+++ b/toolkit/losses.py
@@ -0,0 +1,97 @@
+import torch
+from .llvae import LosslessLatentEncoder
+
+
+def total_variation(image):
+ """
+ Compute normalized total variation.
+ Inputs:
+ - image: PyTorch Variable of shape (N, C, H, W)
+ Returns:
+ - TV: total variation normalized by the number of elements
+ """
+ n_elements = image.shape[1] * image.shape[2] * image.shape[3]
+ return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
+ torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
+
+
+class ComparativeTotalVariation(torch.nn.Module):
+ """
+ Compute the comparative loss in tv between two images. to match their tv
+ """
+
+ def forward(self, pred, target):
+ return torch.abs(total_variation(pred) - total_variation(target))
+
+
+# Gradient penalty
+def get_gradient_penalty(critic, real, fake, device):
+ with torch.autocast(device_type='cuda'):
+ real = real.float()
+ fake = fake.float()
+ alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
+ interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
+ if torch.isnan(interpolates).any():
+            print('interpolates is nan')
+ d_interpolates = critic(interpolates)
+ fake = torch.ones(real.size(0), 1, device=device)
+
+ if torch.isnan(d_interpolates).any():
+            print('d_interpolates is nan')
+ gradients = torch.autograd.grad(
+ outputs=d_interpolates,
+ inputs=interpolates,
+ grad_outputs=fake,
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ )[0]
+
+ # see if any are nan
+ if torch.isnan(gradients).any():
+ print('gradients is nan')
+
+ gradients = gradients.view(gradients.size(0), -1)
+ gradient_norm = gradients.norm(2, dim=1)
+ gradient_penalty = ((gradient_norm - 1) ** 2).mean()
+ return gradient_penalty.float()
+
+
+class PatternLoss(torch.nn.Module):
+ def __init__(self, pattern_size=4, dtype=torch.float32):
+ super().__init__()
+ self.pattern_size = pattern_size
+ self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype)
+
+ def forward(self, pred, target):
+ pred_latents = self.llvae_encoder(pred)
+ target_latents = self.llvae_encoder(target)
+
+ matrix_pixels = self.pattern_size * self.pattern_size
+
+ color_chans = pred_latents.shape[1] // 3
+ # pytorch
+ r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans, color_chans, color_chans], 1)
+ r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans, color_chans, color_chans], 1)
+
+ def separated_chan_loss(latent_chan):
+ nonlocal matrix_pixels
+ chan_mean = torch.mean(latent_chan, dim=[1, 2, 3])
+ chan_splits = torch.split(latent_chan, [1 for i in range(matrix_pixels)], 1)
+ chan_loss = None
+ for chan in chan_splits:
+ this_mean = torch.mean(chan, dim=[1, 2, 3])
+ this_chan_loss = torch.abs(this_mean - chan_mean)
+ if chan_loss is None:
+ chan_loss = this_chan_loss
+ else:
+ chan_loss = chan_loss + this_chan_loss
+ chan_loss = chan_loss * (1 / matrix_pixels)
+ return chan_loss
+
+ r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target))
+ g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target))
+ b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target))
+ return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333
+
+
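
Usage sketch (not part of the patch) for the losses above; shapes and the pattern size are arbitrary, and the inputs are created on the device the PatternLoss encoder picks internally. get_gradient_penalty is omitted here since it needs a critic network.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pred = torch.rand(2, 3, 64, 64, device=device)
target = torch.rand(2, 3, 64, 64, device=device)

print(total_variation(pred))                      # normalized total variation of the batch
print(ComparativeTotalVariation()(pred, target))  # | TV(pred) - TV(target) |

pattern_loss = PatternLoss(pattern_size=4)
print(list(pattern_loss(pred, target).shape))     # [2] -> one loss value per sample
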
diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py
new file mode 100644
index 0000000000000000000000000000000000000000..84021b49cdd3853972721924c1f957203e17e49d
--- /dev/null
+++ b/toolkit/lycoris_special.py
@@ -0,0 +1,373 @@
+import math
+import os
+from typing import Optional, Union, List, Type
+
+import torch
+from lycoris.kohya import LycorisNetwork, LoConModule
+from lycoris.modules.glora import GLoRAModule
+from torch import nn
+from transformers import CLIPTextModel
+from torch.nn import functional as F
+from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
+
+# diffusers specific stuff
+LINEAR_MODULES = [
+ 'Linear',
+ 'LoRACompatibleLinear'
+]
+CONV_MODULES = [
+ 'Conv2d',
+ 'LoRACompatibleConv'
+]
+
+class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin):
+ def __init__(
+ self,
+ lora_name, org_module: nn.Module,
+ multiplier=1.0,
+ lora_dim=4, alpha=1,
+ dropout=0., rank_dropout=0., module_dropout=0.,
+ use_cp=False,
+ network: 'LycorisSpecialNetwork' = None,
+ use_bias=False,
+ **kwargs,
+ ):
+ """ if alpha == 0 or None, alpha is rank (no scaling). """
+ # call super of super
+ ToolkitModuleMixin.__init__(self, network=network)
+ torch.nn.Module.__init__(self)
+ self.lora_name = lora_name
+ self.lora_dim = lora_dim
+ self.cp = False
+
+ # check if parent has bias. if not force use_bias to False
+ if org_module.bias is None:
+ use_bias = False
+
+ self.scalar = nn.Parameter(torch.tensor(0.0))
+ orig_module_name = org_module.__class__.__name__
+ if orig_module_name in CONV_MODULES:
+ self.isconv = True
+ # For general LoCon
+ in_dim = org_module.in_channels
+ k_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ out_dim = org_module.out_channels
+ self.down_op = F.conv2d
+ self.up_op = F.conv2d
+ if use_cp and k_size != (1, 1):
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
+ self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
+ self.cp = True
+ else:
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
+ self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias)
+ elif orig_module_name in LINEAR_MODULES:
+ self.isconv = False
+ self.down_op = F.linear
+ self.up_op = F.linear
+ if orig_module_name == 'GroupNorm':
+ # RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32)
+ in_dim = org_module.num_channels
+ out_dim = org_module.num_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias)
+ else:
+ raise NotImplementedError
+ self.shape = org_module.weight.shape
+
+ if dropout:
+ self.dropout = nn.Dropout(dropout)
+ else:
+ self.dropout = nn.Identity()
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+        self.register_buffer('alpha', torch.tensor(alpha))  # can be treated as a constant
+
+ # same as microsoft's
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.kaiming_uniform_(self.lora_up.weight)
+ if self.cp:
+ torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
+
+ self.multiplier = multiplier
+ self.org_module = [org_module]
+ self.register_load_state_dict_post_hook(self.load_weight_hook)
+
+ def load_weight_hook(self, *args, **kwargs):
+ self.scalar = nn.Parameter(torch.ones_like(self.scalar))
+
+
+class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
+ UNET_TARGET_REPLACE_MODULE = [
+ "Transformer2DModel",
+ "ResnetBlock2D",
+ "Downsample2D",
+ "Upsample2D",
+ # 'UNet2DConditionModel',
+ # 'Conv2d',
+ # 'Timesteps',
+ # 'TimestepEmbedding',
+ # 'Linear',
+ # 'SiLU',
+ # 'ModuleList',
+ # 'DownBlock2D',
+ # 'ResnetBlock2D', # need
+ # 'GroupNorm',
+ # 'LoRACompatibleConv',
+ # 'LoRACompatibleLinear',
+ # 'Dropout',
+ # 'CrossAttnDownBlock2D', # needed
+ # 'Transformer2DModel', # maybe not, has duplicates
+ # 'BasicTransformerBlock', # duplicates
+ # 'LayerNorm',
+ # 'Attention',
+ # 'FeedForward',
+ # 'GEGLU',
+ # 'UpBlock2D',
+ # 'UNetMidBlock2DCrossAttn'
+ ]
+ UNET_TARGET_REPLACE_NAME = [
+ "conv_in",
+ "conv_out",
+ "time_embedding.linear_1",
+ "time_embedding.linear_2",
+ ]
+ def __init__(
+ self,
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+ unet,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ rank_dropout: Optional[float] = None,
+ module_dropout: Optional[float] = None,
+ conv_lora_dim: Optional[int] = None,
+ conv_alpha: Optional[float] = None,
+ use_cp: Optional[bool] = False,
+ network_module: Type[object] = LoConSpecialModule,
+ train_unet: bool = True,
+ train_text_encoder: bool = True,
+ use_text_encoder_1: bool = True,
+ use_text_encoder_2: bool = True,
+ use_bias: bool = False,
+ is_lorm: bool = False,
+ **kwargs,
+ ) -> None:
+ # call ToolkitNetworkMixin super
+ ToolkitNetworkMixin.__init__(
+ self,
+ train_text_encoder=train_text_encoder,
+ train_unet=train_unet,
+ is_lorm=is_lorm,
+ **kwargs
+ )
+ # call the parent of the parent LycorisNetwork
+ torch.nn.Module.__init__(self)
+
+ # LyCORIS unique stuff
+ if dropout is None:
+ dropout = 0
+ if rank_dropout is None:
+ rank_dropout = 0
+ if module_dropout is None:
+ module_dropout = 0
+ self.train_unet = train_unet
+ self.train_text_encoder = train_text_encoder
+
+ self.torch_multiplier = None
+ # triggers a tensor update
+ self.multiplier = multiplier
+ self.lora_dim = lora_dim
+
+ if not self.ENABLE_CONV or conv_lora_dim is None:
+ conv_lora_dim = 0
+ conv_alpha = 0
+
+ self.conv_lora_dim = int(conv_lora_dim)
+ if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
+ print('Apply different lora dim for conv layer')
+ print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
+ elif self.conv_lora_dim == 0:
+ print('Disable conv layer')
+
+ self.alpha = alpha
+ self.conv_alpha = float(conv_alpha)
+ if self.conv_lora_dim and self.alpha != self.conv_alpha:
+ print('Apply different alpha value for conv layer')
+ print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')
+
+ if 1 >= dropout >= 0:
+ print(f'Use Dropout value: {dropout}')
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ # create module instances
+ def create_modules(
+ prefix,
+ root_module: torch.nn.Module,
+ target_replace_modules,
+ target_replace_names=[]
+ ) -> List[network_module]:
+ print('Create LyCORIS Module')
+ loras = []
+            # iterate over every submodule of the root module
+            named_modules = root_module.named_modules()
+
+ for name, module in named_modules:
+ module_name = module.__class__.__name__
+ if module_name in target_replace_modules:
+ if module_name in self.MODULE_ALGO_MAP:
+ algo = self.MODULE_ALGO_MAP[module_name]
+ else:
+ algo = network_module
+ for child_name, child_module in module.named_modules():
+ lora_name = prefix + '.' + name + '.' + child_name
+ lora_name = lora_name.replace('.', '_')
+ if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'):
+ print(f"{lora_name}")
+
+ if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0:
+ lora = algo(
+ lora_name, child_module, self.multiplier,
+ self.lora_dim, self.alpha,
+ self.dropout, self.rank_dropout, self.module_dropout,
+ use_cp,
+ network=self,
+ parent=module,
+ use_bias=use_bias,
+ **kwargs
+ )
+ elif child_module.__class__.__name__ in CONV_MODULES:
+ k_size, *_ = child_module.kernel_size
+ if k_size == 1 and lora_dim > 0:
+ lora = algo(
+ lora_name, child_module, self.multiplier,
+ self.lora_dim, self.alpha,
+ self.dropout, self.rank_dropout, self.module_dropout,
+ use_cp,
+ network=self,
+ parent=module,
+ use_bias=use_bias,
+ **kwargs
+ )
+ elif conv_lora_dim > 0:
+ lora = algo(
+ lora_name, child_module, self.multiplier,
+ self.conv_lora_dim, self.conv_alpha,
+ self.dropout, self.rank_dropout, self.module_dropout,
+ use_cp,
+ network=self,
+ parent=module,
+ use_bias=use_bias,
+ **kwargs
+ )
+ else:
+ continue
+ else:
+ continue
+ loras.append(lora)
+ elif name in target_replace_names:
+ if name in self.NAME_ALGO_MAP:
+ algo = self.NAME_ALGO_MAP[name]
+ else:
+ algo = network_module
+ lora_name = prefix + '.' + name
+ lora_name = lora_name.replace('.', '_')
+ if module.__class__.__name__ == 'Linear' and lora_dim > 0:
+ lora = algo(
+ lora_name, module, self.multiplier,
+ self.lora_dim, self.alpha,
+ self.dropout, self.rank_dropout, self.module_dropout,
+ use_cp,
+ parent=module,
+ network=self,
+ use_bias=use_bias,
+ **kwargs
+ )
+ elif module.__class__.__name__ == 'Conv2d':
+ k_size, *_ = module.kernel_size
+ if k_size == 1 and lora_dim > 0:
+ lora = algo(
+ lora_name, module, self.multiplier,
+ self.lora_dim, self.alpha,
+ self.dropout, self.rank_dropout, self.module_dropout,
+ use_cp,
+ network=self,
+ parent=module,
+ use_bias=use_bias,
+ **kwargs
+ )
+ elif conv_lora_dim > 0:
+ lora = algo(
+ lora_name, module, self.multiplier,
+ self.conv_lora_dim, self.conv_alpha,
+ self.dropout, self.rank_dropout, self.module_dropout,
+ use_cp,
+ network=self,
+ parent=module,
+ use_bias=use_bias,
+ **kwargs
+ )
+ else:
+ continue
+ else:
+ continue
+ loras.append(lora)
+ return loras
+
+ if network_module == GLoRAModule:
+ print('GLoRA enabled, only train transformer')
+ # only train transformer (for GLoRA)
+ LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [
+ "Transformer2DModel",
+ "Attention",
+ ]
+ LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = []
+
+ if isinstance(text_encoder, list):
+ text_encoders = text_encoder
+ use_index = True
+ else:
+ text_encoders = [text_encoder]
+ use_index = False
+
+ self.text_encoder_loras = []
+ if self.train_text_encoder:
+ for i, te in enumerate(text_encoders):
+ if not use_text_encoder_1 and i == 0:
+ continue
+ if not use_text_encoder_2 and i == 1:
+ continue
+ self.text_encoder_loras.extend(create_modules(
+ LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
+ te,
+ LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+ ))
+ print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
+ if self.train_unet:
+ self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
+ LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
+ else:
+ self.unet_loras = []
+ print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
+
+ self.weights_sd = None
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af11ee9ef52c9c0a42ac34afc3e3fa42c7c4b83d
--- /dev/null
+++ b/toolkit/lycoris_utils.py
@@ -0,0 +1,536 @@
+# heavily based on https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/utils.py
+
+from typing import *
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torch.linalg as linalg
+
+from tqdm import tqdm
+from collections import OrderedDict
+
+
+def make_sparse(t: torch.Tensor, sparsity=0.95):
+ abs_t = torch.abs(t)
+ np_array = abs_t.detach().cpu().numpy()
+ quan = float(np.quantile(np_array, sparsity))
+ sparse_t = t.masked_fill(abs_t < quan, 0)
+ return sparse_t
+
+
+def extract_conv(
+ weight: Union[torch.Tensor, nn.Parameter],
+ mode='fixed',
+ mode_param=0,
+ device='cpu',
+ is_cp=False,
+) -> Tuple[nn.Parameter, nn.Parameter]:
+ weight = weight.to(device)
+ out_ch, in_ch, kernel_size, _ = weight.shape
+
+ U, S, Vh = linalg.svd(weight.reshape(out_ch, -1))
+
+ if mode == 'fixed':
+ lora_rank = mode_param
+ elif mode == 'threshold':
+ assert mode_param >= 0
+ lora_rank = torch.sum(S > mode_param)
+ elif mode == 'ratio':
+ assert 1 >= mode_param >= 0
+ min_s = torch.max(S) * mode_param
+ lora_rank = torch.sum(S > min_s)
+ elif mode == 'quantile' or mode == 'percentile':
+ assert 1 >= mode_param >= 0
+ s_cum = torch.cumsum(S, dim=0)
+ min_cum_sum = mode_param * torch.sum(S)
+ lora_rank = torch.sum(s_cum < min_cum_sum)
+ else:
+ raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
+ lora_rank = max(1, lora_rank)
+ lora_rank = min(out_ch, in_ch, lora_rank)
+ if lora_rank >= out_ch / 2 and not is_cp:
+ return weight, 'full'
+
+ U = U[:, :lora_rank]
+ S = S[:lora_rank]
+ U = U @ torch.diag(S)
+ Vh = Vh[:lora_rank, :]
+
+ diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
+ extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
+ extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
+ del U, S, Vh, weight
+ return (extract_weight_A, extract_weight_B, diff), 'low rank'
+
+
+def extract_linear(
+ weight: Union[torch.Tensor, nn.Parameter],
+ mode='fixed',
+ mode_param=0,
+ device='cpu',
+) -> Tuple[nn.Parameter, nn.Parameter]:
+ weight = weight.to(device)
+ out_ch, in_ch = weight.shape
+
+ U, S, Vh = linalg.svd(weight)
+
+ if mode == 'fixed':
+ lora_rank = mode_param
+ elif mode == 'threshold':
+ assert mode_param >= 0
+ lora_rank = torch.sum(S > mode_param)
+ elif mode == 'ratio':
+ assert 1 >= mode_param >= 0
+ min_s = torch.max(S) * mode_param
+ lora_rank = torch.sum(S > min_s)
+ elif mode == 'quantile' or mode == 'percentile':
+ assert 1 >= mode_param >= 0
+ s_cum = torch.cumsum(S, dim=0)
+ min_cum_sum = mode_param * torch.sum(S)
+ lora_rank = torch.sum(s_cum < min_cum_sum)
+ else:
+ raise NotImplementedError('Extract mode should be "fixed", "threshold", "ratio" or "quantile"')
+ lora_rank = max(1, lora_rank)
+ lora_rank = min(out_ch, in_ch, lora_rank)
+ if lora_rank >= out_ch / 2:
+ return weight, 'full'
+
+ U = U[:, :lora_rank]
+ S = S[:lora_rank]
+ U = U @ torch.diag(S)
+ Vh = Vh[:lora_rank, :]
+
+ diff = (weight - U @ Vh).detach()
+ extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
+ extract_weight_B = U.reshape(out_ch, lora_rank).detach()
+ del U, S, Vh, weight
+ return (extract_weight_A, extract_weight_B, diff), 'low rank'
+
+
+def extract_diff(
+ base_model,
+ db_model,
+ mode='fixed',
+ linear_mode_param=0,
+ conv_mode_param=0,
+ extract_device='cpu',
+ use_bias=False,
+ sparsity=0.98,
+ small_conv=True,
+ linear_only=False,
+ extract_unet=True,
+ extract_text_encoder=True,
+):
+ meta = OrderedDict()
+
+ UNET_TARGET_REPLACE_MODULE = [
+ "Transformer2DModel",
+ "Attention",
+ "ResnetBlock2D",
+ "Downsample2D",
+ "Upsample2D"
+ ]
+ UNET_TARGET_REPLACE_NAME = [
+ "conv_in",
+ "conv_out",
+ "time_embedding.linear_1",
+ "time_embedding.linear_2",
+ ]
+ if linear_only:
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
+ UNET_TARGET_REPLACE_NAME = [
+ "conv_in",
+ "conv_out",
+ ]
+
+ if not extract_unet:
+ UNET_TARGET_REPLACE_MODULE = []
+ UNET_TARGET_REPLACE_NAME = []
+
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+
+ if not extract_text_encoder:
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = []
+
+ LORA_PREFIX_UNET = 'lora_unet'
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
+
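+    # build a LoRA-style state dict from the weight difference between the base module
+    # and the tuned module: every changed Linear/Conv layer is decomposed into
+    # lora_down / lora_up (plus lora_mid for k x k convs) via SVD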
+ def make_state_dict(
+ prefix,
+ root_module: torch.nn.Module,
+ target_module: torch.nn.Module,
+ target_replace_modules,
+ target_replace_names=[]
+ ):
+ loras = {}
+ temp = {}
+ temp_name = {}
+
+ for name, module in root_module.named_modules():
+ if module.__class__.__name__ in target_replace_modules:
+ temp[name] = {}
+ for child_name, child_module in module.named_modules():
+ if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
+ continue
+ temp[name][child_name] = child_module.weight
+ elif name in target_replace_names:
+ temp_name[name] = module.weight
+
+ for name, module in tqdm(list(target_module.named_modules())):
+ if name in temp:
+ weights = temp[name]
+ for child_name, child_module in module.named_modules():
+ lora_name = prefix + '.' + name + '.' + child_name
+ lora_name = lora_name.replace('.', '_')
+ layer = child_module.__class__.__name__
+ if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
+ root_weight = child_module.weight
+ if torch.allclose(root_weight, weights[child_name]):
+ continue
+
+ if layer == 'Linear' or layer == 'LoRACompatibleLinear':
+ weight, decompose_mode = extract_linear(
+ (child_module.weight - weights[child_name]),
+ mode,
+ linear_mode_param,
+ device=extract_device,
+ )
+ if decompose_mode == 'low rank':
+ extract_a, extract_b, diff = weight
+ elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
+ is_linear = (child_module.weight.shape[2] == 1
+ and child_module.weight.shape[3] == 1)
+ if not is_linear and linear_only:
+ continue
+ weight, decompose_mode = extract_conv(
+ (child_module.weight - weights[child_name]),
+ mode,
+ linear_mode_param if is_linear else conv_mode_param,
+ device=extract_device,
+ )
+ if decompose_mode == 'low rank':
+ extract_a, extract_b, diff = weight
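+                        # when small_conv is set, split the rank-sized k x k factor once more
+                        # so the spatial kernel lives only in a small lora_mid tensor, leaving
+                        # 1x1 lora_down / lora_up convs (CP-style decomposition)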
+ if small_conv and not is_linear and decompose_mode == 'low rank':
+ dim = extract_a.size(0)
+ (extract_c, extract_a, _), _ = extract_conv(
+ extract_a.transpose(0, 1),
+ 'fixed', dim,
+ extract_device, True
+ )
+ extract_a = extract_a.transpose(0, 1)
+ extract_c = extract_c.transpose(0, 1)
+ loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
+ diff = child_module.weight - torch.einsum(
+ 'i j k l, j r, p i -> p r k l',
+ extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
+ ).detach().cpu().contiguous()
+ del extract_c
+ else:
+ continue
+ if decompose_mode == 'low rank':
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
+ loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
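+                            # anything the low-rank factors could not capture is optionally stored
+                            # as a sparse residual (bias_indices / bias_values / bias_size)
+                            # at the requested sparsity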
+ if use_bias:
+ diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
+ sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
+
+ indices = sparse_diff.indices().to(torch.int16)
+ values = sparse_diff.values().half()
+ loras[f'{lora_name}.bias_indices'] = indices
+ loras[f'{lora_name}.bias_values'] = values
+ loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
+ del extract_a, extract_b, diff
+ elif decompose_mode == 'full':
+ loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
+ else:
+ raise NotImplementedError
+ elif name in temp_name:
+ weights = temp_name[name]
+ lora_name = prefix + '.' + name
+ lora_name = lora_name.replace('.', '_')
+ layer = module.__class__.__name__
+
+ if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
+ root_weight = module.weight
+ if torch.allclose(root_weight, weights):
+ continue
+
+ if layer == 'Linear' or layer == 'LoRACompatibleLinear':
+ weight, decompose_mode = extract_linear(
+ (root_weight - weights),
+ mode,
+ linear_mode_param,
+ device=extract_device,
+ )
+ if decompose_mode == 'low rank':
+ extract_a, extract_b, diff = weight
+ elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
+ is_linear = (
+ root_weight.shape[2] == 1
+ and root_weight.shape[3] == 1
+ )
+ if not is_linear and linear_only:
+ continue
+ weight, decompose_mode = extract_conv(
+ (root_weight - weights),
+ mode,
+ linear_mode_param if is_linear else conv_mode_param,
+ device=extract_device,
+ )
+ if decompose_mode == 'low rank':
+ extract_a, extract_b, diff = weight
+ if small_conv and not is_linear and decompose_mode == 'low rank':
+ dim = extract_a.size(0)
+ (extract_c, extract_a, _), _ = extract_conv(
+ extract_a.transpose(0, 1),
+ 'fixed', dim,
+ extract_device, True
+ )
+ extract_a = extract_a.transpose(0, 1)
+ extract_c = extract_c.transpose(0, 1)
+ loras[f'{lora_name}.lora_mid.weight'] = extract_c.detach().cpu().contiguous().half()
+ diff = root_weight - torch.einsum(
+ 'i j k l, j r, p i -> p r k l',
+ extract_c, extract_a.flatten(1, -1), extract_b.flatten(1, -1)
+ ).detach().cpu().contiguous()
+ del extract_c
+ else:
+ continue
+ if decompose_mode == 'low rank':
+ loras[f'{lora_name}.lora_down.weight'] = extract_a.detach().cpu().contiguous().half()
+ loras[f'{lora_name}.lora_up.weight'] = extract_b.detach().cpu().contiguous().half()
+ loras[f'{lora_name}.alpha'] = torch.Tensor([extract_a.shape[0]]).half()
+ if use_bias:
+ diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
+ sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()
+
+ indices = sparse_diff.indices().to(torch.int16)
+ values = sparse_diff.values().half()
+ loras[f'{lora_name}.bias_indices'] = indices
+ loras[f'{lora_name}.bias_values'] = values
+ loras[f'{lora_name}.bias_size'] = torch.tensor(diff.shape).to(torch.int16)
+ del extract_a, extract_b, diff
+ elif decompose_mode == 'full':
+ loras[f'{lora_name}.diff'] = weight.detach().cpu().contiguous().half()
+ else:
+ raise NotImplementedError
+ return loras
+
+ text_encoder_loras = make_state_dict(
+ LORA_PREFIX_TEXT_ENCODER,
+ base_model[0], db_model[0],
+ TEXT_ENCODER_TARGET_REPLACE_MODULE
+ )
+
+ unet_loras = make_state_dict(
+ LORA_PREFIX_UNET,
+ base_model[2], db_model[2],
+ UNET_TARGET_REPLACE_MODULE,
+ UNET_TARGET_REPLACE_NAME
+ )
+ print(len(text_encoder_loras), len(unet_loras))
+    # the | operator merges the two state dicts into one (requires Python 3.9+)
+ return (text_encoder_loras | unet_loras), meta
+
+
+def get_module(
+ lyco_state_dict: Dict,
+ lora_name
+):
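+    # detect which LyCORIS module type the keys describe:
+    # 'locon' (standard LoRA / LoCon), 'hada' (LoHa), 'ia3', 'kron' (LoKr) or a 'full' diff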
+ if f'{lora_name}.lora_up.weight' in lyco_state_dict:
+ up = lyco_state_dict[f'{lora_name}.lora_up.weight']
+ down = lyco_state_dict[f'{lora_name}.lora_down.weight']
+ mid = lyco_state_dict.get(f'{lora_name}.lora_mid.weight', None)
+ alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
+ return 'locon', (up, down, mid, alpha)
+ elif f'{lora_name}.hada_w1_a' in lyco_state_dict:
+ w1a = lyco_state_dict[f'{lora_name}.hada_w1_a']
+ w1b = lyco_state_dict[f'{lora_name}.hada_w1_b']
+ w2a = lyco_state_dict[f'{lora_name}.hada_w2_a']
+ w2b = lyco_state_dict[f'{lora_name}.hada_w2_b']
+ t1 = lyco_state_dict.get(f'{lora_name}.hada_t1', None)
+ t2 = lyco_state_dict.get(f'{lora_name}.hada_t2', None)
+ alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
+ return 'hada', (w1a, w1b, w2a, w2b, t1, t2, alpha)
+ elif f'{lora_name}.weight' in lyco_state_dict:
+ weight = lyco_state_dict[f'{lora_name}.weight']
+ on_input = lyco_state_dict.get(f'{lora_name}.on_input', False)
+ return 'ia3', (weight, on_input)
+ elif (f'{lora_name}.lokr_w1' in lyco_state_dict
+ or f'{lora_name}.lokr_w1_a' in lyco_state_dict):
+ w1 = lyco_state_dict.get(f'{lora_name}.lokr_w1', None)
+ w1a = lyco_state_dict.get(f'{lora_name}.lokr_w1_a', None)
+ w1b = lyco_state_dict.get(f'{lora_name}.lokr_w1_b', None)
+ w2 = lyco_state_dict.get(f'{lora_name}.lokr_w2', None)
+ w2a = lyco_state_dict.get(f'{lora_name}.lokr_w2_a', None)
+ w2b = lyco_state_dict.get(f'{lora_name}.lokr_w2_b', None)
+ t1 = lyco_state_dict.get(f'{lora_name}.lokr_t1', None)
+ t2 = lyco_state_dict.get(f'{lora_name}.lokr_t2', None)
+ alpha = lyco_state_dict.get(f'{lora_name}.alpha', None)
+ return 'kron', (w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha)
+ elif f'{lora_name}.diff' in lyco_state_dict:
+ return 'full', lyco_state_dict[f'{lora_name}.diff']
+ else:
+ return 'None', ()
+
+
+def cp_weight_from_conv(
+ up, down, mid
+):
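+    # rebuild the full conv kernel from the CP factors: up/down are 1x1 channel
+    # projections and mid carries the k x k spatial kernel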
+ up = up.reshape(up.size(0), up.size(1))
+ down = down.reshape(down.size(0), down.size(1))
+ return torch.einsum('m n w h, i m, n j -> i j w h', mid, up, down)
+
+
+def cp_weight(
+ wa, wb, t
+):
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
+@torch.no_grad()
+def rebuild_weight(module_type, params, orig_weight, scale=1):
+ if orig_weight is None:
+ return orig_weight
+ merged = orig_weight
+ if module_type == 'locon':
+ up, down, mid, alpha = params
+ if alpha is not None:
+ scale *= alpha / up.size(1)
+ if mid is not None:
+ rebuild = cp_weight_from_conv(up, down, mid)
+ else:
+ rebuild = up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)
+ merged = orig_weight + rebuild.reshape(orig_weight.shape) * scale
+ del up, down, mid, alpha, params, rebuild
+ elif module_type == 'hada':
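+        # LoHa: the weight delta is the element-wise (Hadamard) product of two
+        # low-rank reconstructions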
+ w1a, w1b, w2a, w2b, t1, t2, alpha = params
+ if alpha is not None:
+ scale *= alpha / w1b.size(0)
+ if t1 is not None:
+ rebuild1 = cp_weight(w1a, w1b, t1)
+ else:
+ rebuild1 = w1a @ w1b
+ if t2 is not None:
+ rebuild2 = cp_weight(w2a, w2b, t2)
+ else:
+ rebuild2 = w2a @ w2b
+ rebuild = (rebuild1 * rebuild2).reshape(orig_weight.shape)
+ merged = orig_weight + rebuild * scale
+ del w1a, w1b, w2a, w2b, t1, t2, alpha, params, rebuild, rebuild1, rebuild2
+ elif module_type == 'ia3':
+ weight, on_input = params
+ if not on_input:
+ weight = weight.reshape(-1, 1)
+ merged = orig_weight + weight * orig_weight * scale
+ del weight, on_input, params
+ elif module_type == 'kron':
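+        # LoKr: the weight delta is the Kronecker product of two factors, each of
+        # which may itself be stored in low-rank form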
+ w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha = params
+ if alpha is not None and (w1b is not None or w2b is not None):
+            scale *= alpha / (w1b.size(0) if w1b is not None else w2b.size(0))
+ if w1a is not None and w1b is not None:
+            if t1 is not None:
+ w1 = cp_weight(w1a, w1b, t1)
+ else:
+ w1 = w1a @ w1b
+ if w2a is not None and w2b is not None:
+            if t2 is not None:
+ w2 = cp_weight(w2a, w2b, t2)
+ else:
+ w2 = w2a @ w2b
+ rebuild = torch.kron(w1, w2).reshape(orig_weight.shape)
+ merged = orig_weight + rebuild * scale
+ del w1, w1a, w1b, w2, w2a, w2b, t1, t2, alpha, params, rebuild
+ elif module_type == 'full':
+ rebuild = params.reshape(orig_weight.shape)
+ merged = orig_weight + rebuild * scale
+ del params, rebuild
+
+ return merged
+
+
+def merge(
+ base_model,
+ lyco_state_dict,
+ scale: float = 1.0,
+ device='cpu'
+):
+ UNET_TARGET_REPLACE_MODULE = [
+ "Transformer2DModel",
+ "Attention",
+ "ResnetBlock2D",
+ "Downsample2D",
+ "Upsample2D"
+ ]
+ UNET_TARGET_REPLACE_NAME = [
+ "conv_in",
+ "conv_out",
+ "time_embedding.linear_1",
+ "time_embedding.linear_2",
+ ]
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
+ LORA_PREFIX_UNET = 'lora_unet'
+ LORA_PREFIX_TEXT_ENCODER = 'lora_te'
+ merged = 0
+
+ def merge_state_dict(
+ prefix,
+ root_module: torch.nn.Module,
+ lyco_state_dict: Dict[str, torch.Tensor],
+ target_replace_modules,
+ target_replace_names=[]
+ ):
+ nonlocal merged
+ for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
+ if module.__class__.__name__ in target_replace_modules:
+ for child_name, child_module in module.named_modules():
+ if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d',
+ 'LoRACompatibleConv'}:
+ continue
+ lora_name = prefix + '.' + name + '.' + child_name
+ lora_name = lora_name.replace('.', '_')
+
+ result = rebuild_weight(*get_module(
+ lyco_state_dict, lora_name
+ ), getattr(child_module, 'weight'), scale)
+ if result is not None:
+ merged += 1
+ child_module.requires_grad_(False)
+ child_module.weight.copy_(result)
+ elif name in target_replace_names:
+ lora_name = prefix + '.' + name
+ lora_name = lora_name.replace('.', '_')
+
+ result = rebuild_weight(*get_module(
+ lyco_state_dict, lora_name
+ ), getattr(module, 'weight'), scale)
+ if result is not None:
+ merged += 1
+ module.requires_grad_(False)
+ module.weight.copy_(result)
+
+ if device == 'cpu':
+ for k, v in tqdm(list(lyco_state_dict.items()), desc='Converting Dtype'):
+ lyco_state_dict[k] = v.float()
+
+ merge_state_dict(
+ LORA_PREFIX_TEXT_ENCODER,
+ base_model[0],
+ lyco_state_dict,
+ TEXT_ENCODER_TARGET_REPLACE_MODULE,
+ UNET_TARGET_REPLACE_NAME
+ )
+ merge_state_dict(
+ LORA_PREFIX_UNET,
+ base_model[2],
+ lyco_state_dict,
+ UNET_TARGET_REPLACE_MODULE,
+ UNET_TARGET_REPLACE_NAME
+ )
+    print(f'{merged} modules were merged')
diff --git a/toolkit/metadata.py b/toolkit/metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a5c36adae70feb2624a84a3b8dbe05f24ed60ed
--- /dev/null
+++ b/toolkit/metadata.py
@@ -0,0 +1,88 @@
+import json
+from collections import OrderedDict
+from io import BytesIO
+
+import safetensors.torch
+from safetensors import safe_open
+
+from info import software_meta
+from toolkit.train_tools import addnet_hash_legacy
+from toolkit.train_tools import addnet_hash_safetensors
+
+
+def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True) -> OrderedDict:
+ # stringify the meta and reparse OrderedDict to replace [name] with name
+ meta_string = json.dumps(meta)
+ if name is not None:
+ meta_string = meta_string.replace("[name]", name)
+ save_meta = json.loads(meta_string, object_pairs_hook=OrderedDict)
+ if add_software_info:
+ save_meta["software"] = software_meta
+ # safetensors can only be one level deep
+ for key, value in save_meta.items():
+ # if not float, int, bool, or str, convert to json string
+ if not isinstance(value, str):
+ save_meta[key] = json.dumps(value)
+ # add the pt format
+ save_meta["format"] = "pt"
+ return save_meta
+
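+# e.g. get_meta_for_safetensors(OrderedDict(ss_output_name="[name]"), name="my_lora") returns an
+# OrderedDict where "[name]" is replaced, non-string values are JSON-encoded, and "format" is "pt"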
+
+def add_model_hash_to_meta(state_dict, meta: OrderedDict) -> OrderedDict:
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
+ save time on indexing the model later."""
+
+ # Because writing user metadata to the file can change the result of
+ # sd_models.model_hash(), only retain the training metadata for purposes of
+ # calculating the hash, as they are meant to be immutable
+ metadata = {k: v for k, v in meta.items() if k.startswith("ss_")}
+
+    serialized = safetensors.torch.save(state_dict, metadata)
+    b = BytesIO(serialized)
+
+ model_hash = addnet_hash_safetensors(b)
+ legacy_hash = addnet_hash_legacy(b)
+ meta["sshs_model_hash"] = model_hash
+ meta["sshs_legacy_hash"] = legacy_hash
+ return meta
+
+
+def add_base_model_info_to_meta(
+ meta: OrderedDict,
+ base_model: str = None,
+ is_v1: bool = False,
+ is_v2: bool = False,
+ is_xl: bool = False,
+) -> OrderedDict:
+ if base_model is not None:
+ meta['ss_base_model'] = base_model
+ elif is_v2:
+ meta['ss_v2'] = True
+ meta['ss_base_model_version'] = 'sd_2.1'
+
+ elif is_xl:
+ meta['ss_base_model_version'] = 'sdxl_1.0'
+ else:
+ # default to v1.5
+ meta['ss_base_model_version'] = 'sd_1.5'
+ return meta
+
+
+def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
+ parsed_meta = OrderedDict()
+ for key, value in meta.items():
+ try:
+ parsed_meta[key] = json.loads(value)
+ except json.decoder.JSONDecodeError:
+ parsed_meta[key] = value
+ return parsed_meta
+
+
+def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
+ try:
+ with safe_open(file_path, framework="pt") as f:
+ metadata = f.metadata()
+ return parse_metadata_from_safetensors(metadata)
+ except Exception as e:
+ print(f"Error loading metadata from {file_path}: {e}")
+ return OrderedDict()
diff --git a/toolkit/models/DoRA.py b/toolkit/models/DoRA.py
new file mode 100644
index 0000000000000000000000000000000000000000..653575e94e640ae2900230d1e3b36f8d3ea5f93e
--- /dev/null
+++ b/toolkit/models/DoRA.py
@@ -0,0 +1,146 @@
+#based off https://github.com/catid/dora/blob/main/dora.py
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import TYPE_CHECKING, Union, List
+
+from optimum.quanto import QBytesTensor, QTensor
+
+from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin
+
+if TYPE_CHECKING:
+ from toolkit.lora_special import LoRASpecialNetwork
+
+# diffusers specific stuff
+LINEAR_MODULES = [
+ 'Linear',
+ 'LoRACompatibleLinear'
+ # 'GroupNorm',
+]
+CONV_MODULES = [
+ 'Conv2d',
+ 'LoRACompatibleConv'
+]
+
+def transpose(weight, fan_in_fan_out):
+ if not fan_in_fan_out:
+ return weight
+
+ if isinstance(weight, torch.nn.Parameter):
+ return torch.nn.Parameter(weight.T)
+ return weight.T
+
+class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
+ # def __init__(self, d_in, d_out, rank=4, weight=None, bias=None):
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ network: 'LoRASpecialNetwork' = None,
+ use_bias: bool = False,
+ **kwargs
+ ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        self.can_merge_in = False
+ ToolkitModuleMixin.__init__(self, network=network)
+ torch.nn.Module.__init__(self)
+ self.lora_name = lora_name
+ self.scalar = torch.tensor(1.0)
+
+ self.lora_dim = lora_dim
+
+ if org_module.__class__.__name__ in CONV_MODULES:
+ raise NotImplementedError("Convolutional layers are not supported yet")
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+        # self.register_buffer("alpha", torch.tensor(alpha))  # treat as a constant
+
+ self.multiplier: Union[float, List[float]] = multiplier
+ # wrap the original module so it doesn't get weights updated
+ self.org_module = [org_module]
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+ self.is_checkpointing = False
+
+ d_out = org_module.out_features
+ d_in = org_module.in_features
+
+ std_dev = 1 / torch.sqrt(torch.tensor(self.lora_dim).float())
+ # self.lora_up = nn.Parameter(torch.randn(d_out, self.lora_dim) * std_dev) # lora_A
+ # self.lora_down = nn.Parameter(torch.zeros(self.lora_dim, d_in)) # lora_B
+ self.lora_up = nn.Linear(self.lora_dim, d_out, bias=False) # lora_B
+ # self.lora_up.weight.data = torch.randn_like(self.lora_up.weight.data) * std_dev
+ self.lora_up.weight.data = torch.zeros_like(self.lora_up.weight.data)
+ # self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
+ # self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
+ self.lora_down = nn.Linear(d_in, self.lora_dim, bias=False) # lora_A
+ # self.lora_down.weight.data = torch.zeros_like(self.lora_down.weight.data)
+ self.lora_down.weight.data = torch.randn_like(self.lora_down.weight.data) * std_dev
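+        # lora_up (B) starts at zero so the initial weight delta is zero, while
+        # lora_down (A) gets a small random init scaled by 1/sqrt(rank)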
+
+ # m = Magnitude column-wise across output dimension
+ weight = self.get_orig_weight()
+ weight = weight.to(self.lora_up.weight.device, dtype=self.lora_up.weight.dtype)
+ lora_weight = self.lora_up.weight @ self.lora_down.weight
+ weight_norm = self._get_weight_norm(weight, lora_weight)
+ self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True)
+
+ def apply_to(self):
+ self.org_forward = self.org_module[0].forward
+ self.org_module[0].forward = self.forward
+ # del self.org_module
+
+ def get_orig_weight(self):
+ weight = self.org_module[0].weight
+ if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
+ return weight.dequantize().data.detach()
+ else:
+ return weight.data.detach()
+
+ def get_orig_bias(self):
+ if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
+ return self.org_module[0].bias.data.detach()
+ return None
+
+ # def dora_forward(self, x, *args, **kwargs):
+ # lora = torch.matmul(self.lora_A, self.lora_B)
+ # adapted = self.get_orig_weight() + lora
+ # column_norm = adapted.norm(p=2, dim=0, keepdim=True)
+ # norm_adapted = adapted / column_norm
+ # calc_weights = self.magnitude * norm_adapted
+ # return F.linear(x, calc_weights, self.get_orig_bias())
+
+ def _get_weight_norm(self, weight, scaled_lora_weight) -> torch.Tensor:
+ # calculate L2 norm of weight matrix, column-wise
+ weight = weight + scaled_lora_weight.to(weight.device)
+ weight_norm = torch.linalg.norm(weight, dim=1)
+ return weight_norm
+
+ def apply_dora(self, x, scaled_lora_weight):
+ # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L192
+ # lora weight is already scaled
+
+ # magnitude = self.lora_magnitude_vector[active_adapter]
+ weight = self.get_orig_weight()
+ weight = weight.to(scaled_lora_weight.device, dtype=scaled_lora_weight.dtype)
+ weight_norm = self._get_weight_norm(weight, scaled_lora_weight)
+ # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
+ # "[...] we suggest treating ||V +∆V ||_c in
+ # Eq. (5) as a constant, thereby detaching it from the gradient
+ # graph. This means that while ||V + ∆V ||_c dynamically
+ # reflects the updates of ∆V , it won’t receive any gradient
+ # during backpropagation"
+ weight_norm = weight_norm.detach()
+ dora_weight = transpose(weight + scaled_lora_weight, False)
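+        # return only the DoRA correction term (m / ||W + dW||_c - 1) * (x (W + dW)^T);
+        # the surrounding network mixin is assumed to add this on top of the base module
+        # output plus the scaled LoRA output, mirroring the PEFT reference implementation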
+ return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x.to(dora_weight.dtype), dora_weight)
diff --git a/toolkit/models/LoRAFormer.py b/toolkit/models/LoRAFormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..78bb460de413129f7b94782476bbd89b112c8542
--- /dev/null
+++ b/toolkit/models/LoRAFormer.py
@@ -0,0 +1,267 @@
+import math
+import weakref
+
+import torch
+import torch.nn as nn
+from typing import TYPE_CHECKING, List, Dict, Any
+from toolkit.models.clip_fusion import ZipperBlock
+from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
+import sys
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+from ipadapter.ip_adapter.resampler import Resampler
+from collections import OrderedDict
+
+if TYPE_CHECKING:
+ from toolkit.lora_special import LoRAModule
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, d_model, nhead, dim_feedforward):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
+ self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
+ self.feed_forward = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ nn.ReLU(),
+ nn.Linear(dim_feedforward, d_model)
+ )
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ def forward(self, x, cross_attn_input):
+ # Self-attention
+ attn_output, _ = self.self_attn(x, x, x)
+ x = self.norm1(x + attn_output)
+
+ # Cross-attention
+ cross_attn_output, _ = self.cross_attn(x, cross_attn_input, cross_attn_input)
+ x = self.norm2(x + cross_attn_output)
+
+ # Feed-forward
+ ff_output = self.feed_forward(x)
+ x = self.norm3(x + ff_output)
+
+ return x
+
+
+class InstantLoRAMidModule(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ lora_module: 'LoRAModule',
+ instant_lora_module: 'InstantLoRAModule',
+ up_shape: list = None,
+ down_shape: list = None,
+ ):
+ super(InstantLoRAMidModule, self).__init__()
+ self.up_shape = up_shape
+ self.down_shape = down_shape
+ self.index = index
+ self.lora_module_ref = weakref.ref(lora_module)
+ self.instant_lora_module_ref = weakref.ref(instant_lora_module)
+
+ self.embed = None
+
+ def down_forward(self, x, *args, **kwargs):
+ # get the embed
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ down_size = math.prod(self.down_shape)
+ down_weight = self.embed[:, :down_size]
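+        # the projected embedding stores [flattened down weight | flattened up weight],
+        # so the first prod(down_shape) entries are this module's per-sample down weight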
+
+ batch_size = x.shape[0]
+
+ # unconditional
+ if down_weight.shape[0] * 2 == batch_size:
+ down_weight = torch.cat([down_weight] * 2, dim=0)
+
+ weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
+ x_chunks = torch.chunk(x, batch_size, dim=0)
+
+ x_out = []
+ for i in range(batch_size):
+ weight_chunk = weight_chunks[i]
+ x_chunk = x_chunks[i]
+ # reshape
+ weight_chunk = weight_chunk.view(self.down_shape)
+ # check if is conv or linear
+ if len(weight_chunk.shape) == 4:
+ padding = 0
+ if weight_chunk.shape[-1] == 3:
+ padding = 1
+ x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
+ else:
+ # run a simple linear layer with the down weight
+ x_chunk = x_chunk @ weight_chunk.T
+ x_out.append(x_chunk)
+ x = torch.cat(x_out, dim=0)
+ return x
+
+
+ def up_forward(self, x, *args, **kwargs):
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ up_size = math.prod(self.up_shape)
+ up_weight = self.embed[:, -up_size:]
+
+ batch_size = x.shape[0]
+
+ # unconditional
+ if up_weight.shape[0] * 2 == batch_size:
+ up_weight = torch.cat([up_weight] * 2, dim=0)
+
+ weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
+ x_chunks = torch.chunk(x, batch_size, dim=0)
+
+ x_out = []
+ for i in range(batch_size):
+ weight_chunk = weight_chunks[i]
+ x_chunk = x_chunks[i]
+ # reshape
+ weight_chunk = weight_chunk.view(self.up_shape)
+ # check if is conv or linear
+ if len(weight_chunk.shape) == 4:
+ padding = 0
+ if weight_chunk.shape[-1] == 3:
+ padding = 1
+ x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
+ else:
+ # run a simple linear layer with the down weight
+ x_chunk = x_chunk @ weight_chunk.T
+ x_out.append(x_chunk)
+ x = torch.cat(x_out, dim=0)
+ return x
+
+
+# Initialize the network
+# num_blocks = 8
+# d_model = 1024 # Adjust as needed
+# nhead = 16 # Adjust as needed
+# dim_feedforward = 4096 # Adjust as needed
+# latent_dim = 1695744
+
+class LoRAFormer(torch.nn.Module):
+ def __init__(
+ self,
+ num_blocks,
+ d_model=1024,
+ nhead=16,
+ dim_feedforward=4096,
+ sd: 'StableDiffusion'=None,
+ ):
+ super(LoRAFormer, self).__init__()
+ # self.linear = torch.nn.Linear(2, 1)
+ self.sd_ref = weakref.ref(sd)
+ self.dim = sd.network.lora_dim
+
+ # stores the projection vector. Grabbed by modules
+ self.img_embeds: List[torch.Tensor] = None
+
+ # disable merging in. It is slower on inference
+ self.sd_ref().network.can_merge_in = False
+
+ self.ilora_modules = torch.nn.ModuleList()
+
+ lora_modules = self.sd_ref().network.get_all_modules()
+
+ output_size = 0
+
+ self.embed_lengths = []
+ self.weight_mapping = []
+
+ for idx, lora_module in enumerate(lora_modules):
+ module_dict = lora_module.state_dict()
+ down_shape = list(module_dict['lora_down.weight'].shape)
+ up_shape = list(module_dict['lora_up.weight'].shape)
+
+ self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
+
+ module_size = math.prod(down_shape) + math.prod(up_shape)
+ output_size += module_size
+ self.embed_lengths.append(module_size)
+
+
+ # add a new mid module that will take the original forward and add a vector to it
+ # this will be used to add the vector to the original forward
+ instant_module = InstantLoRAMidModule(
+ idx,
+ lora_module,
+ self,
+ up_shape=up_shape,
+ down_shape=down_shape
+ )
+
+ self.ilora_modules.append(instant_module)
+
+ # replace the LoRA forwards
+ lora_module.lora_down.forward = instant_module.down_forward
+ lora_module.lora_up.forward = instant_module.up_forward
+
+
+ self.output_size = output_size
+
+ self.latent = nn.Parameter(torch.randn(1, output_size))
+ self.latent_proj = nn.Linear(output_size, d_model)
+ self.blocks = nn.ModuleList([
+ TransformerBlock(d_model, nhead, dim_feedforward)
+ for _ in range(num_blocks)
+ ])
+ self.final_proj = nn.Linear(d_model, output_size)
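+        # intended pipeline (as suggested by the modules defined above): project the
+        # learned latent to d_model, refine it with the transformer blocks while
+        # cross-attending to the image embeddings, then project back to the flattened
+        # LoRA weight space and slice it per module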
+
+ self.migrate_weight_mapping()
+
+ def migrate_weight_mapping(self):
+ return
+ # # changes the names of the modules to common ones
+ # keymap = self.sd_ref().network.get_keymap()
+ # save_keymap = {}
+ # if keymap is not None:
+ # for ldm_key, diffusers_key in keymap.items():
+ # # invert them
+ # save_keymap[diffusers_key] = ldm_key
+ #
+ # new_keymap = {}
+ # for key, value in self.weight_mapping:
+ # if key in save_keymap:
+ # new_keymap[save_keymap[key]] = value
+ # else:
+ # print(f"Key {key} not found in keymap")
+ # new_keymap[key] = value
+ # self.weight_mapping = new_keymap
+ # else:
+ # print("No keymap found. Using default names")
+ # return
+
+
+ def forward(self, img_embeds):
+ # expand token rank if only rank 2
+ if len(img_embeds.shape) == 2:
+ img_embeds = img_embeds.unsqueeze(1)
+
+        # refine the learned latent with the transformer blocks, cross-attending to the
+        # image embeddings (assumes img_embeds are already projected to d_model features)
+        batch_size = img_embeds.shape[0]
+        x = self.latent_proj(self.latent).unsqueeze(0).expand(batch_size, -1, -1)
+        for block in self.blocks:
+            x = block(x, img_embeds)
+        # project back to the flattened LoRA weight space and merge the token dimension
+        img_embeds = self.final_proj(x).mean(dim=1)
+
+ self.img_embeds = []
+ # get all the slices
+ start = 0
+ for length in self.embed_lengths:
+ self.img_embeds.append(img_embeds[:, start:start+length])
+ start += length
+
+
+ def get_additional_save_metadata(self) -> Dict[str, Any]:
+        # save the weight mapping and output size (the attributes this class defines)
+        return {
+            "weight_mapping": self.weight_mapping,
+            "output_size": self.output_size,
+        }
+
diff --git a/toolkit/models/RRDB.py b/toolkit/models/RRDB.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a2ad955309d2d5bb7a19e61812e4a4a761fa2e
--- /dev/null
+++ b/toolkit/models/RRDB.py
@@ -0,0 +1,645 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import functools
+import math
+import re
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import block as B
+
+esrgan_safetensors_keys = ['model.0.weight', 'model.0.bias', 'model.1.sub.0.RDB1.conv1.0.weight',
+ 'model.1.sub.0.RDB1.conv1.0.bias', 'model.1.sub.0.RDB1.conv2.0.weight',
+ 'model.1.sub.0.RDB1.conv2.0.bias', 'model.1.sub.0.RDB1.conv3.0.weight',
+ 'model.1.sub.0.RDB1.conv3.0.bias', 'model.1.sub.0.RDB1.conv4.0.weight',
+ 'model.1.sub.0.RDB1.conv4.0.bias', 'model.1.sub.0.RDB1.conv5.0.weight',
+ 'model.1.sub.0.RDB1.conv5.0.bias', 'model.1.sub.0.RDB2.conv1.0.weight',
+ 'model.1.sub.0.RDB2.conv1.0.bias', 'model.1.sub.0.RDB2.conv2.0.weight',
+ 'model.1.sub.0.RDB2.conv2.0.bias', 'model.1.sub.0.RDB2.conv3.0.weight',
+ 'model.1.sub.0.RDB2.conv3.0.bias', 'model.1.sub.0.RDB2.conv4.0.weight',
+ 'model.1.sub.0.RDB2.conv4.0.bias', 'model.1.sub.0.RDB2.conv5.0.weight',
+ 'model.1.sub.0.RDB2.conv5.0.bias', 'model.1.sub.0.RDB3.conv1.0.weight',
+ 'model.1.sub.0.RDB3.conv1.0.bias', 'model.1.sub.0.RDB3.conv2.0.weight',
+ 'model.1.sub.0.RDB3.conv2.0.bias', 'model.1.sub.0.RDB3.conv3.0.weight',
+ 'model.1.sub.0.RDB3.conv3.0.bias', 'model.1.sub.0.RDB3.conv4.0.weight',
+ 'model.1.sub.0.RDB3.conv4.0.bias', 'model.1.sub.0.RDB3.conv5.0.weight',
+ 'model.1.sub.0.RDB3.conv5.0.bias', 'model.1.sub.1.RDB1.conv1.0.weight',
+ 'model.1.sub.1.RDB1.conv1.0.bias', 'model.1.sub.1.RDB1.conv2.0.weight',
+ 'model.1.sub.1.RDB1.conv2.0.bias', 'model.1.sub.1.RDB1.conv3.0.weight',
+ 'model.1.sub.1.RDB1.conv3.0.bias', 'model.1.sub.1.RDB1.conv4.0.weight',
+ 'model.1.sub.1.RDB1.conv4.0.bias', 'model.1.sub.1.RDB1.conv5.0.weight',
+ 'model.1.sub.1.RDB1.conv5.0.bias', 'model.1.sub.1.RDB2.conv1.0.weight',
+ 'model.1.sub.1.RDB2.conv1.0.bias', 'model.1.sub.1.RDB2.conv2.0.weight',
+ 'model.1.sub.1.RDB2.conv2.0.bias', 'model.1.sub.1.RDB2.conv3.0.weight',
+ 'model.1.sub.1.RDB2.conv3.0.bias', 'model.1.sub.1.RDB2.conv4.0.weight',
+ 'model.1.sub.1.RDB2.conv4.0.bias', 'model.1.sub.1.RDB2.conv5.0.weight',
+ 'model.1.sub.1.RDB2.conv5.0.bias', 'model.1.sub.1.RDB3.conv1.0.weight',
+ 'model.1.sub.1.RDB3.conv1.0.bias', 'model.1.sub.1.RDB3.conv2.0.weight',
+ 'model.1.sub.1.RDB3.conv2.0.bias', 'model.1.sub.1.RDB3.conv3.0.weight',
+ 'model.1.sub.1.RDB3.conv3.0.bias', 'model.1.sub.1.RDB3.conv4.0.weight',
+ 'model.1.sub.1.RDB3.conv4.0.bias', 'model.1.sub.1.RDB3.conv5.0.weight',
+ 'model.1.sub.1.RDB3.conv5.0.bias', 'model.1.sub.2.RDB1.conv1.0.weight',
+ 'model.1.sub.2.RDB1.conv1.0.bias', 'model.1.sub.2.RDB1.conv2.0.weight',
+ 'model.1.sub.2.RDB1.conv2.0.bias', 'model.1.sub.2.RDB1.conv3.0.weight',
+ 'model.1.sub.2.RDB1.conv3.0.bias', 'model.1.sub.2.RDB1.conv4.0.weight',
+ 'model.1.sub.2.RDB1.conv4.0.bias', 'model.1.sub.2.RDB1.conv5.0.weight',
+ 'model.1.sub.2.RDB1.conv5.0.bias', 'model.1.sub.2.RDB2.conv1.0.weight',
+ 'model.1.sub.2.RDB2.conv1.0.bias', 'model.1.sub.2.RDB2.conv2.0.weight',
+ 'model.1.sub.2.RDB2.conv2.0.bias', 'model.1.sub.2.RDB2.conv3.0.weight',
+ 'model.1.sub.2.RDB2.conv3.0.bias', 'model.1.sub.2.RDB2.conv4.0.weight',
+ 'model.1.sub.2.RDB2.conv4.0.bias', 'model.1.sub.2.RDB2.conv5.0.weight',
+ 'model.1.sub.2.RDB2.conv5.0.bias', 'model.1.sub.2.RDB3.conv1.0.weight',
+ 'model.1.sub.2.RDB3.conv1.0.bias', 'model.1.sub.2.RDB3.conv2.0.weight',
+ 'model.1.sub.2.RDB3.conv2.0.bias', 'model.1.sub.2.RDB3.conv3.0.weight',
+ 'model.1.sub.2.RDB3.conv3.0.bias', 'model.1.sub.2.RDB3.conv4.0.weight',
+ 'model.1.sub.2.RDB3.conv4.0.bias', 'model.1.sub.2.RDB3.conv5.0.weight',
+ 'model.1.sub.2.RDB3.conv5.0.bias', 'model.1.sub.3.RDB1.conv1.0.weight',
+ 'model.1.sub.3.RDB1.conv1.0.bias', 'model.1.sub.3.RDB1.conv2.0.weight',
+ 'model.1.sub.3.RDB1.conv2.0.bias', 'model.1.sub.3.RDB1.conv3.0.weight',
+ 'model.1.sub.3.RDB1.conv3.0.bias', 'model.1.sub.3.RDB1.conv4.0.weight',
+ 'model.1.sub.3.RDB1.conv4.0.bias', 'model.1.sub.3.RDB1.conv5.0.weight',
+ 'model.1.sub.3.RDB1.conv5.0.bias', 'model.1.sub.3.RDB2.conv1.0.weight',
+ 'model.1.sub.3.RDB2.conv1.0.bias', 'model.1.sub.3.RDB2.conv2.0.weight',
+ 'model.1.sub.3.RDB2.conv2.0.bias', 'model.1.sub.3.RDB2.conv3.0.weight',
+ 'model.1.sub.3.RDB2.conv3.0.bias', 'model.1.sub.3.RDB2.conv4.0.weight',
+ 'model.1.sub.3.RDB2.conv4.0.bias', 'model.1.sub.3.RDB2.conv5.0.weight',
+ 'model.1.sub.3.RDB2.conv5.0.bias', 'model.1.sub.3.RDB3.conv1.0.weight',
+ 'model.1.sub.3.RDB3.conv1.0.bias', 'model.1.sub.3.RDB3.conv2.0.weight',
+ 'model.1.sub.3.RDB3.conv2.0.bias', 'model.1.sub.3.RDB3.conv3.0.weight',
+ 'model.1.sub.3.RDB3.conv3.0.bias', 'model.1.sub.3.RDB3.conv4.0.weight',
+ 'model.1.sub.3.RDB3.conv4.0.bias', 'model.1.sub.3.RDB3.conv5.0.weight',
+ 'model.1.sub.3.RDB3.conv5.0.bias', 'model.1.sub.4.RDB1.conv1.0.weight',
+ 'model.1.sub.4.RDB1.conv1.0.bias', 'model.1.sub.4.RDB1.conv2.0.weight',
+ 'model.1.sub.4.RDB1.conv2.0.bias', 'model.1.sub.4.RDB1.conv3.0.weight',
+ 'model.1.sub.4.RDB1.conv3.0.bias', 'model.1.sub.4.RDB1.conv4.0.weight',
+ 'model.1.sub.4.RDB1.conv4.0.bias', 'model.1.sub.4.RDB1.conv5.0.weight',
+ 'model.1.sub.4.RDB1.conv5.0.bias', 'model.1.sub.4.RDB2.conv1.0.weight',
+ 'model.1.sub.4.RDB2.conv1.0.bias', 'model.1.sub.4.RDB2.conv2.0.weight',
+ 'model.1.sub.4.RDB2.conv2.0.bias', 'model.1.sub.4.RDB2.conv3.0.weight',
+ 'model.1.sub.4.RDB2.conv3.0.bias', 'model.1.sub.4.RDB2.conv4.0.weight',
+ 'model.1.sub.4.RDB2.conv4.0.bias', 'model.1.sub.4.RDB2.conv5.0.weight',
+ 'model.1.sub.4.RDB2.conv5.0.bias', 'model.1.sub.4.RDB3.conv1.0.weight',
+ 'model.1.sub.4.RDB3.conv1.0.bias', 'model.1.sub.4.RDB3.conv2.0.weight',
+ 'model.1.sub.4.RDB3.conv2.0.bias', 'model.1.sub.4.RDB3.conv3.0.weight',
+ 'model.1.sub.4.RDB3.conv3.0.bias', 'model.1.sub.4.RDB3.conv4.0.weight',
+ 'model.1.sub.4.RDB3.conv4.0.bias', 'model.1.sub.4.RDB3.conv5.0.weight',
+ 'model.1.sub.4.RDB3.conv5.0.bias', 'model.1.sub.5.RDB1.conv1.0.weight',
+ 'model.1.sub.5.RDB1.conv1.0.bias', 'model.1.sub.5.RDB1.conv2.0.weight',
+ 'model.1.sub.5.RDB1.conv2.0.bias', 'model.1.sub.5.RDB1.conv3.0.weight',
+ 'model.1.sub.5.RDB1.conv3.0.bias', 'model.1.sub.5.RDB1.conv4.0.weight',
+ 'model.1.sub.5.RDB1.conv4.0.bias', 'model.1.sub.5.RDB1.conv5.0.weight',
+ 'model.1.sub.5.RDB1.conv5.0.bias', 'model.1.sub.5.RDB2.conv1.0.weight',
+ 'model.1.sub.5.RDB2.conv1.0.bias', 'model.1.sub.5.RDB2.conv2.0.weight',
+ 'model.1.sub.5.RDB2.conv2.0.bias', 'model.1.sub.5.RDB2.conv3.0.weight',
+ 'model.1.sub.5.RDB2.conv3.0.bias', 'model.1.sub.5.RDB2.conv4.0.weight',
+ 'model.1.sub.5.RDB2.conv4.0.bias', 'model.1.sub.5.RDB2.conv5.0.weight',
+ 'model.1.sub.5.RDB2.conv5.0.bias', 'model.1.sub.5.RDB3.conv1.0.weight',
+ 'model.1.sub.5.RDB3.conv1.0.bias', 'model.1.sub.5.RDB3.conv2.0.weight',
+ 'model.1.sub.5.RDB3.conv2.0.bias', 'model.1.sub.5.RDB3.conv3.0.weight',
+ 'model.1.sub.5.RDB3.conv3.0.bias', 'model.1.sub.5.RDB3.conv4.0.weight',
+ 'model.1.sub.5.RDB3.conv4.0.bias', 'model.1.sub.5.RDB3.conv5.0.weight',
+ 'model.1.sub.5.RDB3.conv5.0.bias', 'model.1.sub.6.RDB1.conv1.0.weight',
+ 'model.1.sub.6.RDB1.conv1.0.bias', 'model.1.sub.6.RDB1.conv2.0.weight',
+ 'model.1.sub.6.RDB1.conv2.0.bias', 'model.1.sub.6.RDB1.conv3.0.weight',
+ 'model.1.sub.6.RDB1.conv3.0.bias', 'model.1.sub.6.RDB1.conv4.0.weight',
+ 'model.1.sub.6.RDB1.conv4.0.bias', 'model.1.sub.6.RDB1.conv5.0.weight',
+ 'model.1.sub.6.RDB1.conv5.0.bias', 'model.1.sub.6.RDB2.conv1.0.weight',
+ 'model.1.sub.6.RDB2.conv1.0.bias', 'model.1.sub.6.RDB2.conv2.0.weight',
+ 'model.1.sub.6.RDB2.conv2.0.bias', 'model.1.sub.6.RDB2.conv3.0.weight',
+ 'model.1.sub.6.RDB2.conv3.0.bias', 'model.1.sub.6.RDB2.conv4.0.weight',
+ 'model.1.sub.6.RDB2.conv4.0.bias', 'model.1.sub.6.RDB2.conv5.0.weight',
+ 'model.1.sub.6.RDB2.conv5.0.bias', 'model.1.sub.6.RDB3.conv1.0.weight',
+ 'model.1.sub.6.RDB3.conv1.0.bias', 'model.1.sub.6.RDB3.conv2.0.weight',
+ 'model.1.sub.6.RDB3.conv2.0.bias', 'model.1.sub.6.RDB3.conv3.0.weight',
+ 'model.1.sub.6.RDB3.conv3.0.bias', 'model.1.sub.6.RDB3.conv4.0.weight',
+ 'model.1.sub.6.RDB3.conv4.0.bias', 'model.1.sub.6.RDB3.conv5.0.weight',
+ 'model.1.sub.6.RDB3.conv5.0.bias', 'model.1.sub.7.RDB1.conv1.0.weight',
+ 'model.1.sub.7.RDB1.conv1.0.bias', 'model.1.sub.7.RDB1.conv2.0.weight',
+ 'model.1.sub.7.RDB1.conv2.0.bias', 'model.1.sub.7.RDB1.conv3.0.weight',
+ 'model.1.sub.7.RDB1.conv3.0.bias', 'model.1.sub.7.RDB1.conv4.0.weight',
+ 'model.1.sub.7.RDB1.conv4.0.bias', 'model.1.sub.7.RDB1.conv5.0.weight',
+ 'model.1.sub.7.RDB1.conv5.0.bias', 'model.1.sub.7.RDB2.conv1.0.weight',
+ 'model.1.sub.7.RDB2.conv1.0.bias', 'model.1.sub.7.RDB2.conv2.0.weight',
+ 'model.1.sub.7.RDB2.conv2.0.bias', 'model.1.sub.7.RDB2.conv3.0.weight',
+ 'model.1.sub.7.RDB2.conv3.0.bias', 'model.1.sub.7.RDB2.conv4.0.weight',
+ 'model.1.sub.7.RDB2.conv4.0.bias', 'model.1.sub.7.RDB2.conv5.0.weight',
+ 'model.1.sub.7.RDB2.conv5.0.bias', 'model.1.sub.7.RDB3.conv1.0.weight',
+ 'model.1.sub.7.RDB3.conv1.0.bias', 'model.1.sub.7.RDB3.conv2.0.weight',
+ 'model.1.sub.7.RDB3.conv2.0.bias', 'model.1.sub.7.RDB3.conv3.0.weight',
+ 'model.1.sub.7.RDB3.conv3.0.bias', 'model.1.sub.7.RDB3.conv4.0.weight',
+ 'model.1.sub.7.RDB3.conv4.0.bias', 'model.1.sub.7.RDB3.conv5.0.weight',
+ 'model.1.sub.7.RDB3.conv5.0.bias', 'model.1.sub.8.RDB1.conv1.0.weight',
+ 'model.1.sub.8.RDB1.conv1.0.bias', 'model.1.sub.8.RDB1.conv2.0.weight',
+ 'model.1.sub.8.RDB1.conv2.0.bias', 'model.1.sub.8.RDB1.conv3.0.weight',
+ 'model.1.sub.8.RDB1.conv3.0.bias', 'model.1.sub.8.RDB1.conv4.0.weight',
+ 'model.1.sub.8.RDB1.conv4.0.bias', 'model.1.sub.8.RDB1.conv5.0.weight',
+ 'model.1.sub.8.RDB1.conv5.0.bias', 'model.1.sub.8.RDB2.conv1.0.weight',
+ 'model.1.sub.8.RDB2.conv1.0.bias', 'model.1.sub.8.RDB2.conv2.0.weight',
+ 'model.1.sub.8.RDB2.conv2.0.bias', 'model.1.sub.8.RDB2.conv3.0.weight',
+ 'model.1.sub.8.RDB2.conv3.0.bias', 'model.1.sub.8.RDB2.conv4.0.weight',
+ 'model.1.sub.8.RDB2.conv4.0.bias', 'model.1.sub.8.RDB2.conv5.0.weight',
+ 'model.1.sub.8.RDB2.conv5.0.bias', 'model.1.sub.8.RDB3.conv1.0.weight',
+ 'model.1.sub.8.RDB3.conv1.0.bias', 'model.1.sub.8.RDB3.conv2.0.weight',
+ 'model.1.sub.8.RDB3.conv2.0.bias', 'model.1.sub.8.RDB3.conv3.0.weight',
+ 'model.1.sub.8.RDB3.conv3.0.bias', 'model.1.sub.8.RDB3.conv4.0.weight',
+ 'model.1.sub.8.RDB3.conv4.0.bias', 'model.1.sub.8.RDB3.conv5.0.weight',
+ 'model.1.sub.8.RDB3.conv5.0.bias', 'model.1.sub.9.RDB1.conv1.0.weight',
+ 'model.1.sub.9.RDB1.conv1.0.bias', 'model.1.sub.9.RDB1.conv2.0.weight',
+ 'model.1.sub.9.RDB1.conv2.0.bias', 'model.1.sub.9.RDB1.conv3.0.weight',
+ 'model.1.sub.9.RDB1.conv3.0.bias', 'model.1.sub.9.RDB1.conv4.0.weight',
+ 'model.1.sub.9.RDB1.conv4.0.bias', 'model.1.sub.9.RDB1.conv5.0.weight',
+ 'model.1.sub.9.RDB1.conv5.0.bias', 'model.1.sub.9.RDB2.conv1.0.weight',
+ 'model.1.sub.9.RDB2.conv1.0.bias', 'model.1.sub.9.RDB2.conv2.0.weight',
+ 'model.1.sub.9.RDB2.conv2.0.bias', 'model.1.sub.9.RDB2.conv3.0.weight',
+ 'model.1.sub.9.RDB2.conv3.0.bias', 'model.1.sub.9.RDB2.conv4.0.weight',
+ 'model.1.sub.9.RDB2.conv4.0.bias', 'model.1.sub.9.RDB2.conv5.0.weight',
+ 'model.1.sub.9.RDB2.conv5.0.bias', 'model.1.sub.9.RDB3.conv1.0.weight',
+ 'model.1.sub.9.RDB3.conv1.0.bias', 'model.1.sub.9.RDB3.conv2.0.weight',
+ 'model.1.sub.9.RDB3.conv2.0.bias', 'model.1.sub.9.RDB3.conv3.0.weight',
+ 'model.1.sub.9.RDB3.conv3.0.bias', 'model.1.sub.9.RDB3.conv4.0.weight',
+ 'model.1.sub.9.RDB3.conv4.0.bias', 'model.1.sub.9.RDB3.conv5.0.weight',
+ 'model.1.sub.9.RDB3.conv5.0.bias', 'model.1.sub.10.RDB1.conv1.0.weight',
+ 'model.1.sub.10.RDB1.conv1.0.bias', 'model.1.sub.10.RDB1.conv2.0.weight',
+ 'model.1.sub.10.RDB1.conv2.0.bias', 'model.1.sub.10.RDB1.conv3.0.weight',
+ 'model.1.sub.10.RDB1.conv3.0.bias', 'model.1.sub.10.RDB1.conv4.0.weight',
+ 'model.1.sub.10.RDB1.conv4.0.bias', 'model.1.sub.10.RDB1.conv5.0.weight',
+ 'model.1.sub.10.RDB1.conv5.0.bias', 'model.1.sub.10.RDB2.conv1.0.weight',
+ 'model.1.sub.10.RDB2.conv1.0.bias', 'model.1.sub.10.RDB2.conv2.0.weight',
+ 'model.1.sub.10.RDB2.conv2.0.bias', 'model.1.sub.10.RDB2.conv3.0.weight',
+ 'model.1.sub.10.RDB2.conv3.0.bias', 'model.1.sub.10.RDB2.conv4.0.weight',
+ 'model.1.sub.10.RDB2.conv4.0.bias', 'model.1.sub.10.RDB2.conv5.0.weight',
+ 'model.1.sub.10.RDB2.conv5.0.bias', 'model.1.sub.10.RDB3.conv1.0.weight',
+ 'model.1.sub.10.RDB3.conv1.0.bias', 'model.1.sub.10.RDB3.conv2.0.weight',
+ 'model.1.sub.10.RDB3.conv2.0.bias', 'model.1.sub.10.RDB3.conv3.0.weight',
+ 'model.1.sub.10.RDB3.conv3.0.bias', 'model.1.sub.10.RDB3.conv4.0.weight',
+ 'model.1.sub.10.RDB3.conv4.0.bias', 'model.1.sub.10.RDB3.conv5.0.weight',
+ 'model.1.sub.10.RDB3.conv5.0.bias', 'model.1.sub.11.RDB1.conv1.0.weight',
+ 'model.1.sub.11.RDB1.conv1.0.bias', 'model.1.sub.11.RDB1.conv2.0.weight',
+ 'model.1.sub.11.RDB1.conv2.0.bias', 'model.1.sub.11.RDB1.conv3.0.weight',
+ 'model.1.sub.11.RDB1.conv3.0.bias', 'model.1.sub.11.RDB1.conv4.0.weight',
+ 'model.1.sub.11.RDB1.conv4.0.bias', 'model.1.sub.11.RDB1.conv5.0.weight',
+ 'model.1.sub.11.RDB1.conv5.0.bias', 'model.1.sub.11.RDB2.conv1.0.weight',
+ 'model.1.sub.11.RDB2.conv1.0.bias', 'model.1.sub.11.RDB2.conv2.0.weight',
+ 'model.1.sub.11.RDB2.conv2.0.bias', 'model.1.sub.11.RDB2.conv3.0.weight',
+ 'model.1.sub.11.RDB2.conv3.0.bias', 'model.1.sub.11.RDB2.conv4.0.weight',
+ 'model.1.sub.11.RDB2.conv4.0.bias', 'model.1.sub.11.RDB2.conv5.0.weight',
+ 'model.1.sub.11.RDB2.conv5.0.bias', 'model.1.sub.11.RDB3.conv1.0.weight',
+ 'model.1.sub.11.RDB3.conv1.0.bias', 'model.1.sub.11.RDB3.conv2.0.weight',
+ 'model.1.sub.11.RDB3.conv2.0.bias', 'model.1.sub.11.RDB3.conv3.0.weight',
+ 'model.1.sub.11.RDB3.conv3.0.bias', 'model.1.sub.11.RDB3.conv4.0.weight',
+ 'model.1.sub.11.RDB3.conv4.0.bias', 'model.1.sub.11.RDB3.conv5.0.weight',
+ 'model.1.sub.11.RDB3.conv5.0.bias', 'model.1.sub.12.RDB1.conv1.0.weight',
+ 'model.1.sub.12.RDB1.conv1.0.bias', 'model.1.sub.12.RDB1.conv2.0.weight',
+ 'model.1.sub.12.RDB1.conv2.0.bias', 'model.1.sub.12.RDB1.conv3.0.weight',
+ 'model.1.sub.12.RDB1.conv3.0.bias', 'model.1.sub.12.RDB1.conv4.0.weight',
+ 'model.1.sub.12.RDB1.conv4.0.bias', 'model.1.sub.12.RDB1.conv5.0.weight',
+ 'model.1.sub.12.RDB1.conv5.0.bias', 'model.1.sub.12.RDB2.conv1.0.weight',
+ 'model.1.sub.12.RDB2.conv1.0.bias', 'model.1.sub.12.RDB2.conv2.0.weight',
+ 'model.1.sub.12.RDB2.conv2.0.bias', 'model.1.sub.12.RDB2.conv3.0.weight',
+ 'model.1.sub.12.RDB2.conv3.0.bias', 'model.1.sub.12.RDB2.conv4.0.weight',
+ 'model.1.sub.12.RDB2.conv4.0.bias', 'model.1.sub.12.RDB2.conv5.0.weight',
+ 'model.1.sub.12.RDB2.conv5.0.bias', 'model.1.sub.12.RDB3.conv1.0.weight',
+ 'model.1.sub.12.RDB3.conv1.0.bias', 'model.1.sub.12.RDB3.conv2.0.weight',
+ 'model.1.sub.12.RDB3.conv2.0.bias', 'model.1.sub.12.RDB3.conv3.0.weight',
+ 'model.1.sub.12.RDB3.conv3.0.bias', 'model.1.sub.12.RDB3.conv4.0.weight',
+ 'model.1.sub.12.RDB3.conv4.0.bias', 'model.1.sub.12.RDB3.conv5.0.weight',
+ 'model.1.sub.12.RDB3.conv5.0.bias', 'model.1.sub.13.RDB1.conv1.0.weight',
+ 'model.1.sub.13.RDB1.conv1.0.bias', 'model.1.sub.13.RDB1.conv2.0.weight',
+ 'model.1.sub.13.RDB1.conv2.0.bias', 'model.1.sub.13.RDB1.conv3.0.weight',
+ 'model.1.sub.13.RDB1.conv3.0.bias', 'model.1.sub.13.RDB1.conv4.0.weight',
+ 'model.1.sub.13.RDB1.conv4.0.bias', 'model.1.sub.13.RDB1.conv5.0.weight',
+ 'model.1.sub.13.RDB1.conv5.0.bias', 'model.1.sub.13.RDB2.conv1.0.weight',
+ 'model.1.sub.13.RDB2.conv1.0.bias', 'model.1.sub.13.RDB2.conv2.0.weight',
+ 'model.1.sub.13.RDB2.conv2.0.bias', 'model.1.sub.13.RDB2.conv3.0.weight',
+ 'model.1.sub.13.RDB2.conv3.0.bias', 'model.1.sub.13.RDB2.conv4.0.weight',
+ 'model.1.sub.13.RDB2.conv4.0.bias', 'model.1.sub.13.RDB2.conv5.0.weight',
+ 'model.1.sub.13.RDB2.conv5.0.bias', 'model.1.sub.13.RDB3.conv1.0.weight',
+ 'model.1.sub.13.RDB3.conv1.0.bias', 'model.1.sub.13.RDB3.conv2.0.weight',
+ 'model.1.sub.13.RDB3.conv2.0.bias', 'model.1.sub.13.RDB3.conv3.0.weight',
+ 'model.1.sub.13.RDB3.conv3.0.bias', 'model.1.sub.13.RDB3.conv4.0.weight',
+ 'model.1.sub.13.RDB3.conv4.0.bias', 'model.1.sub.13.RDB3.conv5.0.weight',
+ 'model.1.sub.13.RDB3.conv5.0.bias', 'model.1.sub.14.RDB1.conv1.0.weight',
+ 'model.1.sub.14.RDB1.conv1.0.bias', 'model.1.sub.14.RDB1.conv2.0.weight',
+ 'model.1.sub.14.RDB1.conv2.0.bias', 'model.1.sub.14.RDB1.conv3.0.weight',
+ 'model.1.sub.14.RDB1.conv3.0.bias', 'model.1.sub.14.RDB1.conv4.0.weight',
+ 'model.1.sub.14.RDB1.conv4.0.bias', 'model.1.sub.14.RDB1.conv5.0.weight',
+ 'model.1.sub.14.RDB1.conv5.0.bias', 'model.1.sub.14.RDB2.conv1.0.weight',
+ 'model.1.sub.14.RDB2.conv1.0.bias', 'model.1.sub.14.RDB2.conv2.0.weight',
+ 'model.1.sub.14.RDB2.conv2.0.bias', 'model.1.sub.14.RDB2.conv3.0.weight',
+ 'model.1.sub.14.RDB2.conv3.0.bias', 'model.1.sub.14.RDB2.conv4.0.weight',
+ 'model.1.sub.14.RDB2.conv4.0.bias', 'model.1.sub.14.RDB2.conv5.0.weight',
+ 'model.1.sub.14.RDB2.conv5.0.bias', 'model.1.sub.14.RDB3.conv1.0.weight',
+ 'model.1.sub.14.RDB3.conv1.0.bias', 'model.1.sub.14.RDB3.conv2.0.weight',
+ 'model.1.sub.14.RDB3.conv2.0.bias', 'model.1.sub.14.RDB3.conv3.0.weight',
+ 'model.1.sub.14.RDB3.conv3.0.bias', 'model.1.sub.14.RDB3.conv4.0.weight',
+ 'model.1.sub.14.RDB3.conv4.0.bias', 'model.1.sub.14.RDB3.conv5.0.weight',
+ 'model.1.sub.14.RDB3.conv5.0.bias', 'model.1.sub.15.RDB1.conv1.0.weight',
+ 'model.1.sub.15.RDB1.conv1.0.bias', 'model.1.sub.15.RDB1.conv2.0.weight',
+ 'model.1.sub.15.RDB1.conv2.0.bias', 'model.1.sub.15.RDB1.conv3.0.weight',
+ 'model.1.sub.15.RDB1.conv3.0.bias', 'model.1.sub.15.RDB1.conv4.0.weight',
+ 'model.1.sub.15.RDB1.conv4.0.bias', 'model.1.sub.15.RDB1.conv5.0.weight',
+ 'model.1.sub.15.RDB1.conv5.0.bias', 'model.1.sub.15.RDB2.conv1.0.weight',
+ 'model.1.sub.15.RDB2.conv1.0.bias', 'model.1.sub.15.RDB2.conv2.0.weight',
+ 'model.1.sub.15.RDB2.conv2.0.bias', 'model.1.sub.15.RDB2.conv3.0.weight',
+ 'model.1.sub.15.RDB2.conv3.0.bias', 'model.1.sub.15.RDB2.conv4.0.weight',
+ 'model.1.sub.15.RDB2.conv4.0.bias', 'model.1.sub.15.RDB2.conv5.0.weight',
+ 'model.1.sub.15.RDB2.conv5.0.bias', 'model.1.sub.15.RDB3.conv1.0.weight',
+ 'model.1.sub.15.RDB3.conv1.0.bias', 'model.1.sub.15.RDB3.conv2.0.weight',
+ 'model.1.sub.15.RDB3.conv2.0.bias', 'model.1.sub.15.RDB3.conv3.0.weight',
+ 'model.1.sub.15.RDB3.conv3.0.bias', 'model.1.sub.15.RDB3.conv4.0.weight',
+ 'model.1.sub.15.RDB3.conv4.0.bias', 'model.1.sub.15.RDB3.conv5.0.weight',
+ 'model.1.sub.15.RDB3.conv5.0.bias', 'model.1.sub.16.RDB1.conv1.0.weight',
+ 'model.1.sub.16.RDB1.conv1.0.bias', 'model.1.sub.16.RDB1.conv2.0.weight',
+ 'model.1.sub.16.RDB1.conv2.0.bias', 'model.1.sub.16.RDB1.conv3.0.weight',
+ 'model.1.sub.16.RDB1.conv3.0.bias', 'model.1.sub.16.RDB1.conv4.0.weight',
+ 'model.1.sub.16.RDB1.conv4.0.bias', 'model.1.sub.16.RDB1.conv5.0.weight',
+ 'model.1.sub.16.RDB1.conv5.0.bias', 'model.1.sub.16.RDB2.conv1.0.weight',
+ 'model.1.sub.16.RDB2.conv1.0.bias', 'model.1.sub.16.RDB2.conv2.0.weight',
+ 'model.1.sub.16.RDB2.conv2.0.bias', 'model.1.sub.16.RDB2.conv3.0.weight',
+ 'model.1.sub.16.RDB2.conv3.0.bias', 'model.1.sub.16.RDB2.conv4.0.weight',
+ 'model.1.sub.16.RDB2.conv4.0.bias', 'model.1.sub.16.RDB2.conv5.0.weight',
+ 'model.1.sub.16.RDB2.conv5.0.bias', 'model.1.sub.16.RDB3.conv1.0.weight',
+ 'model.1.sub.16.RDB3.conv1.0.bias', 'model.1.sub.16.RDB3.conv2.0.weight',
+ 'model.1.sub.16.RDB3.conv2.0.bias', 'model.1.sub.16.RDB3.conv3.0.weight',
+ 'model.1.sub.16.RDB3.conv3.0.bias', 'model.1.sub.16.RDB3.conv4.0.weight',
+ 'model.1.sub.16.RDB3.conv4.0.bias', 'model.1.sub.16.RDB3.conv5.0.weight',
+ 'model.1.sub.16.RDB3.conv5.0.bias', 'model.1.sub.17.RDB1.conv1.0.weight',
+ 'model.1.sub.17.RDB1.conv1.0.bias', 'model.1.sub.17.RDB1.conv2.0.weight',
+ 'model.1.sub.17.RDB1.conv2.0.bias', 'model.1.sub.17.RDB1.conv3.0.weight',
+ 'model.1.sub.17.RDB1.conv3.0.bias', 'model.1.sub.17.RDB1.conv4.0.weight',
+ 'model.1.sub.17.RDB1.conv4.0.bias', 'model.1.sub.17.RDB1.conv5.0.weight',
+ 'model.1.sub.17.RDB1.conv5.0.bias', 'model.1.sub.17.RDB2.conv1.0.weight',
+ 'model.1.sub.17.RDB2.conv1.0.bias', 'model.1.sub.17.RDB2.conv2.0.weight',
+ 'model.1.sub.17.RDB2.conv2.0.bias', 'model.1.sub.17.RDB2.conv3.0.weight',
+ 'model.1.sub.17.RDB2.conv3.0.bias', 'model.1.sub.17.RDB2.conv4.0.weight',
+ 'model.1.sub.17.RDB2.conv4.0.bias', 'model.1.sub.17.RDB2.conv5.0.weight',
+ 'model.1.sub.17.RDB2.conv5.0.bias', 'model.1.sub.17.RDB3.conv1.0.weight',
+ 'model.1.sub.17.RDB3.conv1.0.bias', 'model.1.sub.17.RDB3.conv2.0.weight',
+ 'model.1.sub.17.RDB3.conv2.0.bias', 'model.1.sub.17.RDB3.conv3.0.weight',
+ 'model.1.sub.17.RDB3.conv3.0.bias', 'model.1.sub.17.RDB3.conv4.0.weight',
+ 'model.1.sub.17.RDB3.conv4.0.bias', 'model.1.sub.17.RDB3.conv5.0.weight',
+ 'model.1.sub.17.RDB3.conv5.0.bias', 'model.1.sub.18.RDB1.conv1.0.weight',
+ 'model.1.sub.18.RDB1.conv1.0.bias', 'model.1.sub.18.RDB1.conv2.0.weight',
+ 'model.1.sub.18.RDB1.conv2.0.bias', 'model.1.sub.18.RDB1.conv3.0.weight',
+ 'model.1.sub.18.RDB1.conv3.0.bias', 'model.1.sub.18.RDB1.conv4.0.weight',
+ 'model.1.sub.18.RDB1.conv4.0.bias', 'model.1.sub.18.RDB1.conv5.0.weight',
+ 'model.1.sub.18.RDB1.conv5.0.bias', 'model.1.sub.18.RDB2.conv1.0.weight',
+ 'model.1.sub.18.RDB2.conv1.0.bias', 'model.1.sub.18.RDB2.conv2.0.weight',
+ 'model.1.sub.18.RDB2.conv2.0.bias', 'model.1.sub.18.RDB2.conv3.0.weight',
+ 'model.1.sub.18.RDB2.conv3.0.bias', 'model.1.sub.18.RDB2.conv4.0.weight',
+ 'model.1.sub.18.RDB2.conv4.0.bias', 'model.1.sub.18.RDB2.conv5.0.weight',
+ 'model.1.sub.18.RDB2.conv5.0.bias', 'model.1.sub.18.RDB3.conv1.0.weight',
+ 'model.1.sub.18.RDB3.conv1.0.bias', 'model.1.sub.18.RDB3.conv2.0.weight',
+ 'model.1.sub.18.RDB3.conv2.0.bias', 'model.1.sub.18.RDB3.conv3.0.weight',
+ 'model.1.sub.18.RDB3.conv3.0.bias', 'model.1.sub.18.RDB3.conv4.0.weight',
+ 'model.1.sub.18.RDB3.conv4.0.bias', 'model.1.sub.18.RDB3.conv5.0.weight',
+ 'model.1.sub.18.RDB3.conv5.0.bias', 'model.1.sub.19.RDB1.conv1.0.weight',
+ 'model.1.sub.19.RDB1.conv1.0.bias', 'model.1.sub.19.RDB1.conv2.0.weight',
+ 'model.1.sub.19.RDB1.conv2.0.bias', 'model.1.sub.19.RDB1.conv3.0.weight',
+ 'model.1.sub.19.RDB1.conv3.0.bias', 'model.1.sub.19.RDB1.conv4.0.weight',
+ 'model.1.sub.19.RDB1.conv4.0.bias', 'model.1.sub.19.RDB1.conv5.0.weight',
+ 'model.1.sub.19.RDB1.conv5.0.bias', 'model.1.sub.19.RDB2.conv1.0.weight',
+ 'model.1.sub.19.RDB2.conv1.0.bias', 'model.1.sub.19.RDB2.conv2.0.weight',
+ 'model.1.sub.19.RDB2.conv2.0.bias', 'model.1.sub.19.RDB2.conv3.0.weight',
+ 'model.1.sub.19.RDB2.conv3.0.bias', 'model.1.sub.19.RDB2.conv4.0.weight',
+ 'model.1.sub.19.RDB2.conv4.0.bias', 'model.1.sub.19.RDB2.conv5.0.weight',
+ 'model.1.sub.19.RDB2.conv5.0.bias', 'model.1.sub.19.RDB3.conv1.0.weight',
+ 'model.1.sub.19.RDB3.conv1.0.bias', 'model.1.sub.19.RDB3.conv2.0.weight',
+ 'model.1.sub.19.RDB3.conv2.0.bias', 'model.1.sub.19.RDB3.conv3.0.weight',
+ 'model.1.sub.19.RDB3.conv3.0.bias', 'model.1.sub.19.RDB3.conv4.0.weight',
+ 'model.1.sub.19.RDB3.conv4.0.bias', 'model.1.sub.19.RDB3.conv5.0.weight',
+ 'model.1.sub.19.RDB3.conv5.0.bias', 'model.1.sub.20.RDB1.conv1.0.weight',
+ 'model.1.sub.20.RDB1.conv1.0.bias', 'model.1.sub.20.RDB1.conv2.0.weight',
+ 'model.1.sub.20.RDB1.conv2.0.bias', 'model.1.sub.20.RDB1.conv3.0.weight',
+ 'model.1.sub.20.RDB1.conv3.0.bias', 'model.1.sub.20.RDB1.conv4.0.weight',
+ 'model.1.sub.20.RDB1.conv4.0.bias', 'model.1.sub.20.RDB1.conv5.0.weight',
+ 'model.1.sub.20.RDB1.conv5.0.bias', 'model.1.sub.20.RDB2.conv1.0.weight',
+ 'model.1.sub.20.RDB2.conv1.0.bias', 'model.1.sub.20.RDB2.conv2.0.weight',
+ 'model.1.sub.20.RDB2.conv2.0.bias', 'model.1.sub.20.RDB2.conv3.0.weight',
+ 'model.1.sub.20.RDB2.conv3.0.bias', 'model.1.sub.20.RDB2.conv4.0.weight',
+ 'model.1.sub.20.RDB2.conv4.0.bias', 'model.1.sub.20.RDB2.conv5.0.weight',
+ 'model.1.sub.20.RDB2.conv5.0.bias', 'model.1.sub.20.RDB3.conv1.0.weight',
+ 'model.1.sub.20.RDB3.conv1.0.bias', 'model.1.sub.20.RDB3.conv2.0.weight',
+ 'model.1.sub.20.RDB3.conv2.0.bias', 'model.1.sub.20.RDB3.conv3.0.weight',
+ 'model.1.sub.20.RDB3.conv3.0.bias', 'model.1.sub.20.RDB3.conv4.0.weight',
+ 'model.1.sub.20.RDB3.conv4.0.bias', 'model.1.sub.20.RDB3.conv5.0.weight',
+ 'model.1.sub.20.RDB3.conv5.0.bias', 'model.1.sub.21.RDB1.conv1.0.weight',
+ 'model.1.sub.21.RDB1.conv1.0.bias', 'model.1.sub.21.RDB1.conv2.0.weight',
+ 'model.1.sub.21.RDB1.conv2.0.bias', 'model.1.sub.21.RDB1.conv3.0.weight',
+ 'model.1.sub.21.RDB1.conv3.0.bias', 'model.1.sub.21.RDB1.conv4.0.weight',
+ 'model.1.sub.21.RDB1.conv4.0.bias', 'model.1.sub.21.RDB1.conv5.0.weight',
+ 'model.1.sub.21.RDB1.conv5.0.bias', 'model.1.sub.21.RDB2.conv1.0.weight',
+ 'model.1.sub.21.RDB2.conv1.0.bias', 'model.1.sub.21.RDB2.conv2.0.weight',
+ 'model.1.sub.21.RDB2.conv2.0.bias', 'model.1.sub.21.RDB2.conv3.0.weight',
+ 'model.1.sub.21.RDB2.conv3.0.bias', 'model.1.sub.21.RDB2.conv4.0.weight',
+ 'model.1.sub.21.RDB2.conv4.0.bias', 'model.1.sub.21.RDB2.conv5.0.weight',
+ 'model.1.sub.21.RDB2.conv5.0.bias', 'model.1.sub.21.RDB3.conv1.0.weight',
+ 'model.1.sub.21.RDB3.conv1.0.bias', 'model.1.sub.21.RDB3.conv2.0.weight',
+ 'model.1.sub.21.RDB3.conv2.0.bias', 'model.1.sub.21.RDB3.conv3.0.weight',
+ 'model.1.sub.21.RDB3.conv3.0.bias', 'model.1.sub.21.RDB3.conv4.0.weight',
+ 'model.1.sub.21.RDB3.conv4.0.bias', 'model.1.sub.21.RDB3.conv5.0.weight',
+ 'model.1.sub.21.RDB3.conv5.0.bias', 'model.1.sub.22.RDB1.conv1.0.weight',
+ 'model.1.sub.22.RDB1.conv1.0.bias', 'model.1.sub.22.RDB1.conv2.0.weight',
+ 'model.1.sub.22.RDB1.conv2.0.bias', 'model.1.sub.22.RDB1.conv3.0.weight',
+ 'model.1.sub.22.RDB1.conv3.0.bias', 'model.1.sub.22.RDB1.conv4.0.weight',
+ 'model.1.sub.22.RDB1.conv4.0.bias', 'model.1.sub.22.RDB1.conv5.0.weight',
+ 'model.1.sub.22.RDB1.conv5.0.bias', 'model.1.sub.22.RDB2.conv1.0.weight',
+ 'model.1.sub.22.RDB2.conv1.0.bias', 'model.1.sub.22.RDB2.conv2.0.weight',
+ 'model.1.sub.22.RDB2.conv2.0.bias', 'model.1.sub.22.RDB2.conv3.0.weight',
+ 'model.1.sub.22.RDB2.conv3.0.bias', 'model.1.sub.22.RDB2.conv4.0.weight',
+ 'model.1.sub.22.RDB2.conv4.0.bias', 'model.1.sub.22.RDB2.conv5.0.weight',
+ 'model.1.sub.22.RDB2.conv5.0.bias', 'model.1.sub.22.RDB3.conv1.0.weight',
+ 'model.1.sub.22.RDB3.conv1.0.bias', 'model.1.sub.22.RDB3.conv2.0.weight',
+ 'model.1.sub.22.RDB3.conv2.0.bias', 'model.1.sub.22.RDB3.conv3.0.weight',
+ 'model.1.sub.22.RDB3.conv3.0.bias', 'model.1.sub.22.RDB3.conv4.0.weight',
+ 'model.1.sub.22.RDB3.conv4.0.bias', 'model.1.sub.22.RDB3.conv5.0.weight',
+ 'model.1.sub.22.RDB3.conv5.0.bias', 'model.1.sub.23.weight', 'model.1.sub.23.bias',
+ 'model.3.weight', 'model.3.bias', 'model.6.weight', 'model.6.bias', 'model.8.weight',
+ 'model.8.bias', 'model.10.weight', 'model.10.bias']
+
+
+# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py,
+# which enhanced the code that was already here
+class RRDBNet(nn.Module):
+ def __init__(
+ self,
+ state_dict,
+ norm=None,
+ act: str = "leakyrelu",
+ upsampler: str = "upconv",
+ mode: B.ConvMode = "CNA",
+ ) -> None:
+ """
+ ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
+ By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
+ and Chen Change Loy.
+        This is the old-arch Residual in Residual Dense Block network, not the
+        newest revision available at github.com/xinntao/ESRGAN. That is
+        intentional: the newer network severely limits how the architecture can
+        be used, with no benefit.
+        This class supports model files from both the new and the old arch.
+ Args:
+ norm: Normalization layer
+ act: Activation layer
+ upsampler: Upsample layer. upconv, pixel_shuffle
+ mode: Convolution mode
+ """
+ super(RRDBNet, self).__init__()
+ self.model_arch = "ESRGAN"
+ self.sub_type = "SR"
+
+ self.state = state_dict
+ self.norm = norm
+ self.act = act
+ self.upsampler = upsampler
+ self.mode = mode
+
+ self.state_map = {
+ # currently supports old, new, and newer RRDBNet arch models
+ # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
+ "model.0.weight": ("conv_first.weight",),
+ "model.0.bias": ("conv_first.bias",),
+ "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
+ "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
+ r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
+ r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
+ r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
+ ),
+ }
+ if "params_ema" in self.state:
+ self.state = self.state["params_ema"]
+ # self.model_arch = "RealESRGAN"
+ self.num_blocks = self.get_num_blocks()
+ self.plus = any("conv1x1" in k for k in self.state.keys())
+ if self.plus:
+ self.model_arch = "ESRGAN+"
+
+ self.state = self.new_to_old_arch(self.state)
+
+ self.key_arr = list(self.state.keys())
+
+ self.in_nc: int = self.state[self.key_arr[0]].shape[1]
+ self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
+
+ self.scale: int = self.get_scale()
+ self.num_filters: int = self.state[self.key_arr[0]].shape[0]
+
+ c2x2 = False
+ if self.state["model.0.weight"].shape[-2] == 2:
+ c2x2 = True
+ self.scale = round(math.sqrt(self.scale / 4))
+ self.model_arch = "ESRGAN-2c2"
+
+ self.supports_fp16 = True
+ self.supports_bfp16 = True
+ self.min_size_restriction = None
+
+ # Detect if pixelunshuffle was used (Real-ESRGAN)
+ if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
+ self.in_nc / 4,
+ self.in_nc / 16,
+ ):
+ self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
+ else:
+ self.shuffle_factor = None
+
+ upsample_block = {
+ "upconv": B.upconv_block,
+ "pixel_shuffle": B.pixelshuffle_block,
+ }.get(self.upsampler)
+ if upsample_block is None:
+ raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
+
+ if self.scale == 3:
+ upsample_blocks = upsample_block(
+ in_nc=self.num_filters,
+ out_nc=self.num_filters,
+ upscale_factor=3,
+ act_type=self.act,
+ c2x2=c2x2,
+ )
+ else:
+ upsample_blocks = [
+ upsample_block(
+ in_nc=self.num_filters,
+ out_nc=self.num_filters,
+ act_type=self.act,
+ c2x2=c2x2,
+ )
+ for _ in range(int(math.log(self.scale, 2)))
+ ]
+
+ self.model = B.sequential(
+ # fea conv
+ B.conv_block(
+ in_nc=self.in_nc,
+ out_nc=self.num_filters,
+ kernel_size=3,
+ norm_type=None,
+ act_type=None,
+ c2x2=c2x2,
+ ),
+ B.ShortcutBlock(
+ B.sequential(
+ # rrdb blocks
+ *[
+ B.RRDB(
+ nf=self.num_filters,
+ kernel_size=3,
+ gc=32,
+ stride=1,
+ bias=True,
+ pad_type="zero",
+ norm_type=self.norm,
+ act_type=self.act,
+ mode="CNA",
+ plus=self.plus,
+ c2x2=c2x2,
+ )
+ for _ in range(self.num_blocks)
+ ],
+ # lr conv
+ B.conv_block(
+ in_nc=self.num_filters,
+ out_nc=self.num_filters,
+ kernel_size=3,
+ norm_type=self.norm,
+ act_type=None,
+ mode=self.mode,
+ c2x2=c2x2,
+ ),
+ )
+ ),
+ *upsample_blocks,
+ # hr_conv0
+ B.conv_block(
+ in_nc=self.num_filters,
+ out_nc=self.num_filters,
+ kernel_size=3,
+ norm_type=None,
+ act_type=self.act,
+ c2x2=c2x2,
+ ),
+ # hr_conv1
+ B.conv_block(
+ in_nc=self.num_filters,
+ out_nc=self.out_nc,
+ kernel_size=3,
+ norm_type=None,
+ act_type=None,
+ c2x2=c2x2,
+ ),
+ )
+
+ # Adjust these properties for calculations outside of the model
+ if self.shuffle_factor:
+ self.in_nc //= self.shuffle_factor ** 2
+ self.scale //= self.shuffle_factor
+
+ self.load_state_dict(self.state, strict=False)
+
+ def new_to_old_arch(self, state):
+ """Convert a new-arch model state dictionary to an old-arch dictionary."""
+ if "params_ema" in state:
+ state = state["params_ema"]
+
+ if "conv_first.weight" not in state:
+ # model is already old arch, this is a loose check, but should be sufficient
+ return state
+
+ # add nb to state keys
+ for kind in ("weight", "bias"):
+ self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
+ f"model.1.sub./NB/.{kind}"
+ ]
+ del self.state_map[f"model.1.sub./NB/.{kind}"]
+
+ old_state = OrderedDict()
+ for old_key, new_keys in self.state_map.items():
+ for new_key in new_keys:
+ if r"\1" in old_key:
+ for k, v in state.items():
+ sub = re.sub(new_key, old_key, k)
+ if sub != k:
+ old_state[sub] = v
+ else:
+ if new_key in state:
+ old_state[old_key] = state[new_key]
+
+ # upconv layers
+ max_upconv = 0
+ for key in state.keys():
+ match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
+ if match is not None:
+ _, key_num, key_type = match.groups()
+ old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
+ max_upconv = max(max_upconv, int(key_num) * 3)
+
+ # final layers
+ for key in state.keys():
+ if key in ("HRconv.weight", "conv_hr.weight"):
+ old_state[f"model.{max_upconv + 2}.weight"] = state[key]
+ elif key in ("HRconv.bias", "conv_hr.bias"):
+ old_state[f"model.{max_upconv + 2}.bias"] = state[key]
+ elif key in ("conv_last.weight",):
+ old_state[f"model.{max_upconv + 4}.weight"] = state[key]
+ elif key in ("conv_last.bias",):
+ old_state[f"model.{max_upconv + 4}.bias"] = state[key]
+
+ # Sort by first numeric value of each layer
+ def compare(item1, item2):
+ parts1 = item1.split(".")
+ parts2 = item2.split(".")
+ int1 = int(parts1[1])
+ int2 = int(parts2[1])
+ return int1 - int2
+
+ sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
+
+ # Rebuild the output dict in the right order
+ out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
+
+ return out_dict
+
+ def get_scale(self, min_part: int = 6) -> int:
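+        # heuristic for old-arch checkpoints: the number of top-level `model.N.weight`
+        # convs with N > min_part equals log2(scale)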
+ n = 0
+ for part in list(self.state):
+ parts = part.split(".")[1:]
+ if len(parts) == 2:
+ part_num = int(parts[0])
+ if part_num > min_part and parts[1] == "weight":
+ n += 1
+ return 2 ** n
+
+ def get_num_blocks(self) -> int:
+ nbs = []
+ state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
+ r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
+ )
+ for state_key in state_keys:
+ for k in self.state:
+ m = re.search(state_key, k)
+ if m:
+ nbs.append(int(m.group(1)))
+ if nbs:
+ break
+        return max(nbs) + 1
+
+ def forward(self, x):
+ if self.shuffle_factor:
+ _, _, h, w = x.size()
+ mod_pad_h = (
+ self.shuffle_factor - h % self.shuffle_factor
+ ) % self.shuffle_factor
+ mod_pad_w = (
+ self.shuffle_factor - w % self.shuffle_factor
+ ) % self.shuffle_factor
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
+ x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
+ x = self.model(x)
+ return x[:, :, : h * self.scale, : w * self.scale]
+ return self.model(x)
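+
+
+# Minimal usage sketch (the checkpoint path is a placeholder, not part of this repo):
+#
+#   import torch
+#   state = torch.load("4x_esrgan.pth", map_location="cpu")
+#   model = RRDBNet(state).eval()
+#   with torch.no_grad():
+#       sr = model(torch.rand(1, model.in_nc, 64, 64))  # -> (1, out_nc, 64 * scale, 64 * scale)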
diff --git a/toolkit/models/auraflow.py b/toolkit/models/auraflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2539bda489ccc1975f42b9c9a027076f8fdfc74
--- /dev/null
+++ b/toolkit/models/auraflow.py
@@ -0,0 +1,127 @@
+import math
+from functools import partial
+
+from torch import nn
+import torch
+
+
+class AuraFlowPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ pos_embed_max_size=None,
+ ):
+ super().__init__()
+
+ self.num_patches = (height // patch_size) * (width // patch_size)
+ self.pos_embed_max_size = pos_embed_max_size
+
+ self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
+ self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
+
+ self.patch_size = patch_size
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+
+ def forward(self, latent):
+ batch_size, num_channels, height, width = latent.size()
+ latent = latent.view(
+ batch_size,
+ num_channels,
+ height // self.patch_size,
+ self.patch_size,
+ width // self.patch_size,
+ self.patch_size,
+ )
+ latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
+ latent = self.proj(latent)
+ try:
+ return latent + self.pos_embed
+ except RuntimeError:
+ raise RuntimeError(
+ f"Positional embeddings are too small for the number of patches. "
+ f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
+ )
+
+
+# reference: ComfyUI implementation of positional-embedding resizing (kept for comparison, unused)
+# def apply_pos_embeds(self, x, h, w):
+# h = (h + 1) // self.patch_size
+# w = (w + 1) // self.patch_size
+# max_dim = max(h, w)
+#
+# cur_dim = self.h_max
+# pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+#
+# if max_dim > cur_dim:
+# pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1,
+# -1)
+# cur_dim = max_dim
+#
+# from_h = (cur_dim - h) // 2
+# from_w = (cur_dim - w) // 2
+# pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
+# return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
+
+ # def patchify(self, x):
+ # B, C, H, W = x.size()
+ # pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
+ # pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
+ #
+ # x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
+ # x = x.view(
+ # B,
+ # C,
+ # (H + 1) // self.patch_size,
+ # self.patch_size,
+ # (W + 1) // self.patch_size,
+ # self.patch_size,
+ # )
+ # x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
+ # return x
+
+def patch_auraflow_pos_embed(pos_embed):
+ # we need to hijack the forward and replace with a custom one. Self is the model
+ def new_forward(self, latent):
+ batch_size, num_channels, height, width = latent.size()
+
+ # add padding to the latent to make it match pos_embed
+ latent_size = height * width * num_channels / 16 # todo check where 16 comes from?
+ pos_embed_size = self.pos_embed.shape[1]
+ if latent_size < pos_embed_size:
+ total_padding = int(pos_embed_size - math.floor(latent_size))
+ total_padding = total_padding // 16
+ pad_height = total_padding // 2
+ pad_width = total_padding - pad_height
+            # reflect-pad the right and bottom edges
+ padding = (0, pad_width, 0, pad_height)
+ latent = torch.nn.functional.pad(latent, padding, mode='reflect')
+ elif latent_size > pos_embed_size:
+            amount_to_remove = int(latent_size - pos_embed_size)  # must be an int to slice
+ latent = latent[:, :, :-amount_to_remove]
+
+ batch_size, num_channels, height, width = latent.size()
+
+ latent = latent.view(
+ batch_size,
+ num_channels,
+ height // self.patch_size,
+ self.patch_size,
+ width // self.patch_size,
+ self.patch_size,
+ )
+ latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
+ latent = self.proj(latent)
+ try:
+ return latent + self.pos_embed
+ except RuntimeError:
+ raise RuntimeError(
+ f"Positional embeddings are too small for the number of patches. "
+ f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
+ )
+
+ pos_embed.forward = partial(new_forward, pos_embed)
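+
+
+# Usage sketch (assumes an AuraFlow transformer with a `pos_embed` module of the type above):
+#   patch_auraflow_pos_embed(transformer.pos_embed)
+# After patching, the forward pads or crops the patchified latent so it lines up with the
+# fixed-size positional embedding instead of erroring on mismatched resolutions.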
diff --git a/toolkit/models/block.py b/toolkit/models/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..76356b5e3eb7c7d6dc4ed1629aac318c264111c5
--- /dev/null
+++ b/toolkit/models/block.py
@@ -0,0 +1,549 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+from collections import OrderedDict
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+import torch
+import torch.nn as nn
+
+
+####################
+# Basic blocks
+####################
+
+
+def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
+ # helper selecting activation
+ # neg_slope: for leakyrelu and init of prelu
+ # n_prelu: for p_relu num_parameters
+ act_type = act_type.lower()
+ if act_type == "relu":
+ layer = nn.ReLU(inplace)
+ elif act_type == "leakyrelu":
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == "prelu":
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ else:
+ raise NotImplementedError(
+ "activation layer [{:s}] is not found".format(act_type)
+ )
+ return layer
+
+
+def norm(norm_type: str, nc: int):
+ # helper selecting normalization layer
+ norm_type = norm_type.lower()
+ if norm_type == "batch":
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == "instance":
+ layer = nn.InstanceNorm2d(nc, affine=False)
+ else:
+ raise NotImplementedError(
+ "normalization layer [{:s}] is not found".format(norm_type)
+ )
+ return layer
+
+
+def pad(pad_type: str, padding):
+ # helper selecting padding layer
+ # if padding is 'zero', do by conv layers
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == "reflect":
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == "replicate":
+ layer = nn.ReplicationPad2d(padding)
+ else:
+ raise NotImplementedError(
+ "padding layer [{:s}] is not implemented".format(pad_type)
+ )
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ConcatBlock(nn.Module):
+ # Concat the output of a submodule to its input
+ def __init__(self, submodule):
+ super(ConcatBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = torch.cat((x, self.sub(x)), dim=1)
+ return output
+
+ def __repr__(self):
+ tmpstr = "Identity .. \n|"
+ modstr = self.sub.__repr__().replace("\n", "\n|")
+ tmpstr = tmpstr + modstr
+ return tmpstr
+
+
+class ShortcutBlock(nn.Module):
+ # Elementwise sum the output of a submodule to its input
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ tmpstr = "Identity + \n|"
+ modstr = self.sub.__repr__().replace("\n", "\n|")
+ tmpstr = tmpstr + modstr
+ return tmpstr
+
+
+class ShortcutBlockSPSR(nn.Module):
+ # Elementwise sum the output of a submodule to its input
+ def __init__(self, submodule):
+ super(ShortcutBlockSPSR, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ return x, self.sub
+
+ def __repr__(self):
+ tmpstr = "Identity + \n|"
+ modstr = self.sub.__repr__().replace("\n", "\n|")
+ tmpstr = tmpstr + modstr
+ return tmpstr
+
+
+def sequential(*args):
+ # Flatten Sequential. It unwraps nn.Sequential.
+ if len(args) == 1:
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError("sequential does not support OrderedDict input.")
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+ConvMode = Literal["CNA", "NAC", "CNAC"]
+
+
+# 2x2x2 Conv Block
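+# (two stacked 2x2 convs, padded 1 then 0, keep the spatial size and cover a 3x3 receptive field)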
+def conv_block_2c2(
+ in_nc,
+ out_nc,
+ act_type="relu",
+):
+ return sequential(
+ nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
+ nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
+ act(act_type) if act_type else None,
+ )
+
+
+def conv_block(
+ in_nc: int,
+ out_nc: int,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ groups=1,
+ bias=True,
+ pad_type="zero",
+ norm_type: str | None = None,
+ act_type: str | None = "relu",
+ mode: ConvMode = "CNA",
+ c2x2=False,
+):
+ """
+ Conv layer with padding, normalization, activation
+ mode: CNA --> Conv -> Norm -> Act
+ NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
+ """
+
+ if c2x2:
+ return conv_block_2c2(in_nc, out_nc, act_type=act_type)
+
+ assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
+ padding = padding if pad_type == "zero" else 0
+
+ c = nn.Conv2d(
+ in_nc,
+ out_nc,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ groups=groups,
+ )
+ a = act(act_type) if act_type else None
+ if mode in ("CNA", "CNAC"):
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == "NAC":
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ # Important!
+ # input----ReLU(inplace)----Conv--+----output
+ # |________________________|
+ # inplace ReLU will modify the input, therefore wrong output
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
+ else:
+ assert False, f"Invalid conv mode {mode}"
+
+
+####################
+# Useful blocks
+####################
+
+
+class ResNetBlock(nn.Module):
+ """
+ ResNet Block, 3-3 style
+ with extra residual scaling used in EDSR
+ (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
+ """
+
+ def __init__(
+ self,
+ in_nc,
+ mid_nc,
+ out_nc,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ groups=1,
+ bias=True,
+ pad_type="zero",
+ norm_type=None,
+ act_type="relu",
+ mode: ConvMode = "CNA",
+ res_scale=1,
+ ):
+ super(ResNetBlock, self).__init__()
+ conv0 = conv_block(
+ in_nc,
+ mid_nc,
+ kernel_size,
+ stride,
+ dilation,
+ groups,
+ bias,
+ pad_type,
+ norm_type,
+ act_type,
+ mode,
+ )
+ if mode == "CNA":
+ act_type = None
+ if mode == "CNAC": # Residual path: |-CNAC-|
+ act_type = None
+ norm_type = None
+ conv1 = conv_block(
+ mid_nc,
+ out_nc,
+ kernel_size,
+ stride,
+ dilation,
+ groups,
+ bias,
+ pad_type,
+ norm_type,
+ act_type,
+ mode,
+ )
+ # if in_nc != out_nc:
+ # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
+ # None, None)
+ # print('Need a projecter in ResNetBlock.')
+ # else:
+ # self.project = lambda x:x
+ self.res = sequential(conv0, conv1)
+ self.res_scale = res_scale
+
+ def forward(self, x):
+ res = self.res(x).mul(self.res_scale)
+ return x + res
+
+
+class RRDB(nn.Module):
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
+
+ def __init__(
+ self,
+ nf,
+ kernel_size=3,
+ gc=32,
+ stride=1,
+ bias: bool = True,
+ pad_type="zero",
+ norm_type=None,
+ act_type="leakyrelu",
+ mode: ConvMode = "CNA",
+ _convtype="Conv2D",
+ _spectral_norm=False,
+ plus=False,
+ c2x2=False,
+ ):
+ super(RRDB, self).__init__()
+ self.RDB1 = ResidualDenseBlock_5C(
+ nf,
+ kernel_size,
+ gc,
+ stride,
+ bias,
+ pad_type,
+ norm_type,
+ act_type,
+ mode,
+ plus=plus,
+ c2x2=c2x2,
+ )
+ self.RDB2 = ResidualDenseBlock_5C(
+ nf,
+ kernel_size,
+ gc,
+ stride,
+ bias,
+ pad_type,
+ norm_type,
+ act_type,
+ mode,
+ plus=plus,
+ c2x2=c2x2,
+ )
+ self.RDB3 = ResidualDenseBlock_5C(
+ nf,
+ kernel_size,
+ gc,
+ stride,
+ bias,
+ pad_type,
+ norm_type,
+ act_type,
+ mode,
+ plus=plus,
+ c2x2=c2x2,
+ )
+
+ def forward(self, x):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ return out * 0.2 + x
+
+
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ style: 5 convs
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ {Rakotonirina} and A. {Rasoanaivo}
+
+ Args:
+ nf (int): Channel number of intermediate features (num_feat).
+ gc (int): Channels for each growth (num_grow_ch: growth channel,
+ i.e. intermediate channels).
+ convtype (str): the type of convolution to use. Default: 'Conv2D'
+ gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
+ trainable parameters)
+ plus (bool): enable the additional residual paths from ESRGAN+
+ (adds trainable parameters)
+ """
+
+ def __init__(
+ self,
+ nf=64,
+ kernel_size=3,
+ gc=32,
+ stride=1,
+ bias: bool = True,
+ pad_type="zero",
+ norm_type=None,
+ act_type="leakyrelu",
+ mode: ConvMode = "CNA",
+ plus=False,
+ c2x2=False,
+ ):
+ super(ResidualDenseBlock_5C, self).__init__()
+
+ ## +
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
+ ## +
+
+ self.conv1 = conv_block(
+ nf,
+ gc,
+ kernel_size,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=norm_type,
+ act_type=act_type,
+ mode=mode,
+ c2x2=c2x2,
+ )
+ self.conv2 = conv_block(
+ nf + gc,
+ gc,
+ kernel_size,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=norm_type,
+ act_type=act_type,
+ mode=mode,
+ c2x2=c2x2,
+ )
+ self.conv3 = conv_block(
+ nf + 2 * gc,
+ gc,
+ kernel_size,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=norm_type,
+ act_type=act_type,
+ mode=mode,
+ c2x2=c2x2,
+ )
+ self.conv4 = conv_block(
+ nf + 3 * gc,
+ gc,
+ kernel_size,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=norm_type,
+ act_type=act_type,
+ mode=mode,
+ c2x2=c2x2,
+ )
+ if mode == "CNA":
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(
+ nf + 4 * gc,
+ nf,
+ 3,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=norm_type,
+ act_type=last_act,
+ mode=mode,
+ c2x2=c2x2,
+ )
+
+ def forward(self, x):
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ # pylint: disable=not-callable
+ x2 = x2 + self.conv1x1(x) # +
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2 # +
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ return x5 * 0.2 + x
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# Upsampler
+####################
+
+
+def pixelshuffle_block(
+ in_nc: int,
+ out_nc: int,
+ upscale_factor=2,
+ kernel_size=3,
+ stride=1,
+ bias=True,
+ pad_type="zero",
+ norm_type: str | None = None,
+ act_type="relu",
+):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(
+ in_nc,
+ out_nc * (upscale_factor ** 2),
+ kernel_size,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=None,
+ act_type=None,
+ )
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
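+    # the conv above expands channels by upscale_factor**2 and PixelShuffle folds them back into
+    # spatial resolution, so the norm/act below see out_nc channels at the upscaled size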
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(
+ in_nc: int,
+ out_nc: int,
+ upscale_factor=2,
+ kernel_size=3,
+ stride=1,
+ bias=True,
+ pad_type="zero",
+ norm_type: str | None = None,
+ act_type="relu",
+ mode="nearest",
+ c2x2=False,
+):
+    # "Up conv" block: nearest-neighbor upsample followed by a conv, which avoids the
+    # checkerboard artifacts described in https://distill.pub/2016/deconv-checkerboard/
+ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(
+ in_nc,
+ out_nc,
+ kernel_size,
+ stride,
+ bias=bias,
+ pad_type=pad_type,
+ norm_type=norm_type,
+ act_type=act_type,
+ c2x2=c2x2,
+ )
+ return sequential(upsample, conv)
diff --git a/toolkit/models/clip_fusion.py b/toolkit/models/clip_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4346fd5ac3eae4c8d91e50df586acc8d4cd2fbe
--- /dev/null
+++ b/toolkit/models/clip_fusion.py
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+
+from toolkit.models.zipper_resampler import ContextualAlphaMask
+
+
+# Conv1d MLP
+# MLP that can alternatively be used as a conv1d over dim 1
+class MLPC(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ hidden_dim,
+ do_conv=False,
+ use_residual=True
+ ):
+ super().__init__()
+ self.do_conv = do_conv
+ if use_residual:
+ assert in_dim == out_dim
+ # dont normalize if using conv
+ if not do_conv:
+ self.layernorm = nn.LayerNorm(in_dim)
+
+ if do_conv:
+ self.fc1 = nn.Conv1d(in_dim, hidden_dim, 1)
+ self.fc2 = nn.Conv1d(hidden_dim, out_dim, 1)
+ else:
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
+
+ self.use_residual = use_residual
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ residual = x
+ if not self.do_conv:
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ if self.use_residual:
+ x = x + residual
+ return x
+
+
+class ZipperBlock(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ in_tokens,
+ out_size,
+ out_tokens,
+ hidden_size,
+ hidden_tokens,
+ ):
+ super().__init__()
+ self.in_size = in_size
+ self.in_tokens = in_tokens
+ self.out_size = out_size
+ self.out_tokens = out_tokens
+ self.hidden_size = hidden_size
+ self.hidden_tokens = hidden_tokens
+        # token remap: (batch_size, in_tokens, in_size) -> (batch_size, out_tokens, in_size) via conv1d
+
+ self.zip_token = MLPC(
+ in_dim=self.in_tokens,
+ out_dim=self.out_tokens,
+ hidden_dim=self.hidden_tokens,
+ do_conv=True, # no need to permute
+ use_residual=False
+ )
+
+        # size remap via a plain linear layer over the last dimension:
+        # (batch_size, out_tokens, in_size) -> (batch_size, out_tokens, out_size)
+ self.zip_size = MLPC(
+ in_dim=self.in_size,
+ out_dim=self.out_size,
+ hidden_dim=self.hidden_size,
+ use_residual=False
+ )
+
+ def forward(self, x):
+ x = self.zip_token(x)
+ x = self.zip_size(x)
+ return x
+
+
+# CLIPFusionModule
+# Fuses vision and text embeddings of arbitrary sizes into a single embedding,
+# remapping both the token count and the hidden dimension.
+class CLIPFusionModule(nn.Module):
+ def __init__(
+ self,
+ text_hidden_size: int = 768,
+ text_tokens: int = 77,
+ vision_hidden_size: int = 1024,
+ vision_tokens: int = 257,
+ num_blocks: int = 1,
+ ):
+ super(CLIPFusionModule, self).__init__()
+
+ self.text_hidden_size = text_hidden_size
+ self.text_tokens = text_tokens
+ self.vision_hidden_size = vision_hidden_size
+ self.vision_tokens = vision_tokens
+
+ self.resampler = ZipperBlock(
+ in_size=self.vision_hidden_size,
+ in_tokens=self.vision_tokens,
+ out_size=self.text_hidden_size,
+ out_tokens=self.text_tokens,
+ hidden_size=self.vision_hidden_size * 2,
+ hidden_tokens=self.vision_tokens * 2
+ )
+
+ self.zipper_blocks = torch.nn.ModuleList([
+ ZipperBlock(
+ in_size=self.text_hidden_size * 2,
+ in_tokens=self.text_tokens,
+ out_size=self.text_hidden_size,
+ out_tokens=self.text_tokens,
+ hidden_size=self.text_hidden_size * 2,
+ hidden_tokens=self.text_tokens * 2
+ ) for i in range(num_blocks)
+ ])
+
+ self.ctx_alpha = ContextualAlphaMask(
+ dim=self.text_hidden_size,
+ )
+
+ self.alpha = nn.Parameter(torch.zeros([text_tokens]) + 0.01)
+
+ def forward(self, text_embeds, vision_embeds):
+ # text_embeds = (batch_size, 77, 768)
+ # vision_embeds = (batch_size, 257, 1024)
+ # output = (batch_size, 77, 768)
+
+ vision_embeds = self.resampler(vision_embeds)
+ x = vision_embeds
+ for i, block in enumerate(self.zipper_blocks):
+ res = x
+ x = torch.cat([text_embeds, x], dim=-1)
+ x = block(x)
+ x = x + res
+
+ # alpha mask
+ ctx_alpha = self.ctx_alpha(text_embeds)
+ # reshape alpha to (1, 77, 1)
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
+
+ x = ctx_alpha * x * alpha
+
+ x = x + text_embeds
+
+ return x
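+
+
+# Usage sketch (default sizes correspond to 77 CLIP text tokens at width 768 and 257 vision
+# tokens at width 1024; any other sizes are untested assumptions):
+#   fusion = CLIPFusionModule()
+#   fused = fusion(text_embeds, vision_embeds)  # (bs, 77, 768), (bs, 257, 1024) -> (bs, 77, 768)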
diff --git a/toolkit/models/clip_pre_processor.py b/toolkit/models/clip_pre_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..7956da0b4d0a5d6b882d4d19d9458bf409cc9b39
--- /dev/null
+++ b/toolkit/models/clip_pre_processor.py
@@ -0,0 +1,123 @@
+import torch
+import torch.nn as nn
+
+
+class UpsampleBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.conv_in = nn.Sequential(
+ nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+ nn.GELU()
+ )
+ self.conv_up = nn.Sequential(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+ nn.GELU()
+ )
+
+ self.conv_out = nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ x = self.conv_up(x)
+ x = self.conv_out(x)
+ return x
+
+
+class CLIPImagePreProcessor(nn.Module):
+ def __init__(
+ self,
+ input_size=896,
+ clip_input_size=224,
+ downscale_factor: int = 16,
+ ):
+ super().__init__()
+ # make sure they are evenly divisible
+ assert input_size % clip_input_size == 0
+ in_channels = 3
+
+ self.input_size = input_size
+ self.clip_input_size = clip_input_size
+ self.downscale_factor = downscale_factor
+
+ subpixel_channels = in_channels * downscale_factor ** 2 # 3 * 16 ** 2 = 768
+ channels = subpixel_channels
+
+ upscale_factor = downscale_factor / int((input_size / clip_input_size)) # 16 / (896 / 224) = 4
+
+ num_upsample_blocks = int(upscale_factor // 2) # 4 // 2 = 2
+
+ # make the residual down up blocks
+ self.upsample_blocks = nn.ModuleList()
+ self.subpixel_blocks = nn.ModuleList()
+ current_channels = channels
+ current_downscale = downscale_factor
+ for _ in range(num_upsample_blocks):
+ # determine the reshuffled channel count for this dimension
+ output_downscale = current_downscale // 2
+ out_channels = in_channels * output_downscale ** 2
+ # out_channels = current_channels // 2
+ self.upsample_blocks.append(UpsampleBlock(current_channels, out_channels))
+ current_channels = out_channels
+ current_downscale = output_downscale
+ self.subpixel_blocks.append(nn.PixelUnshuffle(current_downscale))
+
+ # (bs, 768, 56, 56) -> (bs, 192, 112, 112)
+ # (bs, 192, 112, 112) -> (bs, 48, 224, 224)
+
+ self.conv_out = nn.Conv2d(
+ current_channels,
+ out_channels=3,
+ kernel_size=3,
+ padding=1
+ ) # (bs, 48, 224, 224) -> (bs, 3, 224, 224)
+
+        # average-pool the input down to the CLIP input resolution (e.g. 896 -> 224)
+ # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
+ kernel_size = input_size // clip_input_size
+ self.res_down = nn.AvgPool2d(
+ kernel_size=kernel_size,
+ stride=kernel_size
+ ) # (bs, 3, 896, 896) -> (bs, 3, 224, 224)
+
+        # learnable blend between the processed output and the pooled residual, initialized near 0
+ self.res_blend = nn.Parameter(torch.tensor(0.001)) # (bs, 3, 224, 224) -> (bs, 3, 224, 224)
+
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 896, 896) -> (bs, 768, 56, 56)
+
+ self.conv_in = nn.Sequential(
+ nn.Conv2d(
+ subpixel_channels,
+ channels,
+ kernel_size=3,
+ padding=1
+ ),
+ nn.GELU()
+ ) # (bs, 768, 56, 56) -> (bs, 768, 56, 56)
+
+ # make 2 deep blocks
+
+ def forward(self, x):
+ inputs = x
+ # resize to input_size x input_size
+ x = nn.functional.interpolate(x, size=(self.input_size, self.input_size), mode='bicubic')
+
+ res = self.res_down(inputs)
+
+ x = self.unshuffle(x)
+ x = self.conv_in(x)
+ for up, subpixel in zip(self.upsample_blocks, self.subpixel_blocks):
+ x = up(x)
+ block_res = subpixel(inputs)
+ x = x + block_res
+ x = self.conv_out(x)
+ # blend residual
+ x = x * self.res_blend + res
+ return x
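+
+
+# Shape walk-through with the defaults (input_size 896, clip_input_size 224, downscale_factor 16);
+# other settings are untested assumptions:
+#   (bs, 3, 896, 896) -> PixelUnshuffle(16) -> (bs, 768, 56, 56) -> two upsample blocks
+#   -> (bs, 48, 224, 224) -> conv_out -> (bs, 3, 224, 224), blended with the avg-pooled residual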
diff --git a/toolkit/models/decorator.py b/toolkit/models/decorator.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f45aa9f944370727eed1a362c9bb04ad99fa9b
--- /dev/null
+++ b/toolkit/models/decorator.py
@@ -0,0 +1,33 @@
+import torch
+
+
+class Decorator(torch.nn.Module):
+ def __init__(
+ self,
+ num_tokens: int = 4,
+ token_size: int = 4096,
+ ) -> None:
+ super().__init__()
+
+ self.weight: torch.nn.Parameter = torch.nn.Parameter(
+ torch.randn(num_tokens, token_size)
+ )
+ # ensure it is float32
+ self.weight.data = self.weight.data.float()
+
+ def forward(self, text_embeds: torch.Tensor, is_unconditional=False) -> torch.Tensor:
+ # make sure the param is float32
+ if self.weight.dtype != text_embeds.dtype:
+ self.weight.data = self.weight.data.float()
+ # expand batch to match text_embeds
+ batch_size = text_embeds.shape[0]
+ decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)
+ if is_unconditional:
+            # zero out the decorator embeds for the unconditional branch
+ decorator_embeds = torch.zeros_like(decorator_embeds)
+
+ if decorator_embeds.dtype != text_embeds.dtype:
+ decorator_embeds = decorator_embeds.to(text_embeds.dtype)
+ text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2)
+
+ return text_embeds
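+
+
+# Usage sketch: append the learned tokens to prompt embeddings before they reach the
+# transformer (token_size 4096 is assumed to match the T5 hidden size used by FLUX):
+#   decorator = Decorator(num_tokens=4, token_size=4096)
+#   prompt_embeds = decorator(prompt_embeds)
+#   negative_embeds = decorator(negative_embeds, is_unconditional=True)  # zeroed tokens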
diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ce8786ca86c77c01d73fb6ff8875056ee9d4bc
--- /dev/null
+++ b/toolkit/models/flux.py
@@ -0,0 +1,35 @@
+
+# forward that bypasses the guidance embedding so it can be avoided during training.
+from functools import partial
+
+
+def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(
+ timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = timesteps_emb + pooled_projections
+ return conditioning
+
+# bypass the forward function
+
+
+def bypass_flux_guidance(transformer):
+ if hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
+ return
+ # dont bypass if it doesnt have the guidance embedding
+ if not hasattr(transformer.time_text_embed, 'guidance_embedder'):
+ return
+ transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward
+ transformer.time_text_embed.forward = partial(
+ guidance_embed_bypass_forward, transformer.time_text_embed
+ )
+
+# restore the forward function
+
+
+def restore_flux_guidance(transformer):
+ if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
+ return
+ transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
+ del transformer.time_text_embed._bfg_orig_forward
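+
+
+# Usage sketch: wrap a training step so the guidance embedder is skipped and then restored
+# (`transformer` is assumed to be a diffusers Flux transformer exposing `time_text_embed`):
+#   bypass_flux_guidance(transformer)
+#   ...  # forward/backward without the guidance embedding
+#   restore_flux_guidance(transformer)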
diff --git a/toolkit/models/flux_sage_attn.py b/toolkit/models/flux_sage_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..930a17000c92e7bef14d141c22e9acfef8f92bf4
--- /dev/null
+++ b/toolkit/models/flux_sage_attn.py
@@ -0,0 +1,94 @@
+from typing import Optional
+from diffusers.models.attention_processor import Attention
+import torch
+import torch.nn.functional as F
+
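+# Usage sketch (assumes the optional `sageattention` package is installed and that the
+# transformer exposes diffusers' set_attn_processor):
+#   transformer.set_attn_processor(FluxSageAttnProcessor2_0())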
+
+class FluxSageAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ from sageattention import sageattn
+
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = sageattn(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
\ No newline at end of file
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
new file mode 100644
index 0000000000000000000000000000000000000000..33613ed3193249ebefdea4fc3ff470bdd12a5a27
--- /dev/null
+++ b/toolkit/models/ilora.py
@@ -0,0 +1,364 @@
+import math
+import weakref
+
+import torch
+import torch.nn as nn
+from typing import TYPE_CHECKING, List, Dict, Any
+from toolkit.models.clip_fusion import ZipperBlock
+from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
+import sys
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+from ipadapter.ip_adapter.resampler import Resampler
+from collections import OrderedDict
+
+if TYPE_CHECKING:
+ from toolkit.lora_special import LoRAModule
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+class MLP(nn.Module):
+ def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
+ super().__init__()
+ if use_residual:
+ assert in_dim == out_dim
+ self.layernorm = nn.LayerNorm(in_dim)
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.use_residual = use_residual
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ residual = x
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ if self.use_residual:
+ x = x + residual
+ return x
+
+class LoRAGenerator(torch.nn.Module):
+ def __init__(
+ self,
+ input_size: int = 768, # projection dimension
+ hidden_size: int = 768,
+ head_size: int = 512,
+ num_heads: int = 1,
+ num_mlp_layers: int = 1,
+ output_size: int = 768,
+ dropout: float = 0.0
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.num_heads = num_heads
+ self.simple = False
+
+ self.output_size = output_size
+
+ if self.simple:
+ self.head = nn.Linear(input_size, head_size, bias=False)
+ else:
+ self.lin_in = nn.Linear(input_size, hidden_size)
+
+ self.mlp_blocks = nn.Sequential(*[
+ MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
+ ])
+ self.head = nn.Linear(hidden_size, head_size, bias=False)
+ self.norm = nn.LayerNorm(head_size)
+
+ if num_heads == 1:
+ self.output = nn.Linear(head_size, self.output_size)
+ # for each output block. multiply weights by 0.01
+ with torch.no_grad():
+ self.output.weight.data *= 0.01
+ else:
+ head_output_size = output_size // num_heads
+ self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)])
+ # for each output block. multiply weights by 0.01
+ with torch.no_grad():
+ for output in self.outputs:
+ output.weight.data *= 0.01
+
+ # allow get device
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def forward(self, embedding):
+ if len(embedding.shape) == 2:
+ embedding = embedding.unsqueeze(1)
+
+ x = embedding
+
+ if not self.simple:
+ x = self.lin_in(embedding)
+ x = self.mlp_blocks(x)
+ x = self.head(x)
+ x = self.norm(x)
+
+ if self.num_heads == 1:
+ x = self.output(x)
+ else:
+ out_chunks = torch.chunk(x, self.num_heads, dim=1)
+ x = []
+ for out_layer, chunk in zip(self.outputs, out_chunks):
+ x.append(out_layer(chunk))
+ x = torch.cat(x, dim=-1)
+
+ return x.squeeze(1)
+
+
+class InstantLoRAMidModule(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ lora_module: 'LoRAModule',
+ instant_lora_module: 'InstantLoRAModule',
+ up_shape: list = None,
+ down_shape: list = None,
+ ):
+ super(InstantLoRAMidModule, self).__init__()
+ self.up_shape = up_shape
+ self.down_shape = down_shape
+ self.index = index
+ self.lora_module_ref = weakref.ref(lora_module)
+ self.instant_lora_module_ref = weakref.ref(instant_lora_module)
+
+ self.embed = None
+
+ def down_forward(self, x, *args, **kwargs):
+ # get the embed
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ if x.dtype != self.embed.dtype:
+ x = x.to(self.embed.dtype)
+ down_size = math.prod(self.down_shape)
+ down_weight = self.embed[:, :down_size]
+
+ batch_size = x.shape[0]
+
+ # unconditional
+ if down_weight.shape[0] * 2 == batch_size:
+ down_weight = torch.cat([down_weight] * 2, dim=0)
+
+ weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
+ x_chunks = torch.chunk(x, batch_size, dim=0)
+
+ x_out = []
+ for i in range(batch_size):
+ weight_chunk = weight_chunks[i]
+ x_chunk = x_chunks[i]
+ # reshape
+ weight_chunk = weight_chunk.view(self.down_shape)
+ # check if is conv or linear
+ if len(weight_chunk.shape) == 4:
+ org_module = self.lora_module_ref().orig_module_ref()
+ stride = org_module.stride
+ padding = org_module.padding
+ x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride)
+ else:
+ # run a simple linear layer with the down weight
+ x_chunk = x_chunk @ weight_chunk.T
+ x_out.append(x_chunk)
+ x = torch.cat(x_out, dim=0)
+ return x
+
+
+ def up_forward(self, x, *args, **kwargs):
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ if x.dtype != self.embed.dtype:
+ x = x.to(self.embed.dtype)
+ up_size = math.prod(self.up_shape)
+ up_weight = self.embed[:, -up_size:]
+
+ batch_size = x.shape[0]
+
+ # unconditional
+ if up_weight.shape[0] * 2 == batch_size:
+ up_weight = torch.cat([up_weight] * 2, dim=0)
+
+ weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
+ x_chunks = torch.chunk(x, batch_size, dim=0)
+
+ x_out = []
+ for i in range(batch_size):
+ weight_chunk = weight_chunks[i]
+ x_chunk = x_chunks[i]
+ # reshape
+ weight_chunk = weight_chunk.view(self.up_shape)
+ # check if is conv or linear
+ if len(weight_chunk.shape) == 4:
+ padding = 0
+ if weight_chunk.shape[-1] == 3:
+ padding = 1
+ x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
+ else:
+                # run a simple linear layer with the up weight
+ x_chunk = x_chunk @ weight_chunk.T
+ x_out.append(x_chunk)
+ x = torch.cat(x_out, dim=0)
+ return x
+
+
+
+
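+# InstantLoRAModule predicts a flattened set of per-layer LoRA up/down weights from an image
+# embedding; InstantLoRAMidModule slices that vector per layer and hijacks each LoRA layer's
+# lora_down/lora_up forward so the predicted weights are applied per sample.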
+class InstantLoRAModule(torch.nn.Module):
+ def __init__(
+ self,
+ vision_hidden_size: int,
+ vision_tokens: int,
+ head_dim: int,
+ num_heads: int, # number of heads in the resampler
+ sd: 'StableDiffusion',
+ config=None
+ ):
+ super(InstantLoRAModule, self).__init__()
+ # self.linear = torch.nn.Linear(2, 1)
+ self.sd_ref = weakref.ref(sd)
+ self.dim = sd.network.lora_dim
+ self.vision_hidden_size = vision_hidden_size
+ self.vision_tokens = vision_tokens
+ self.head_dim = head_dim
+ self.num_heads = num_heads
+
+ # stores the projection vector. Grabbed by modules
+ self.img_embeds: List[torch.Tensor] = None
+
+ # disable merging in. It is slower on inference
+ self.sd_ref().network.can_merge_in = False
+
+ self.ilora_modules = torch.nn.ModuleList()
+
+ lora_modules = self.sd_ref().network.get_all_modules()
+
+ output_size = 0
+
+ self.embed_lengths = []
+ self.weight_mapping = []
+
+ for idx, lora_module in enumerate(lora_modules):
+ module_dict = lora_module.state_dict()
+ down_shape = list(module_dict['lora_down.weight'].shape)
+ up_shape = list(module_dict['lora_up.weight'].shape)
+
+ self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
+
+ module_size = math.prod(down_shape) + math.prod(up_shape)
+ output_size += module_size
+ self.embed_lengths.append(module_size)
+
+
+ # add a new mid module that will take the original forward and add a vector to it
+ # this will be used to add the vector to the original forward
+ instant_module = InstantLoRAMidModule(
+ idx,
+ lora_module,
+ self,
+ up_shape=up_shape,
+ down_shape=down_shape
+ )
+
+ self.ilora_modules.append(instant_module)
+
+ # replace the LoRA forwards
+ lora_module.lora_down.forward = instant_module.down_forward
+ lora_module.lora_up.forward = instant_module.up_forward
+
+
+ self.output_size = output_size
+
+ number_formatted_output_size = "{:,}".format(output_size)
+
+ print(f" ILORA output size: {number_formatted_output_size}")
+
+ # if not evenly divisible, error
+ if self.output_size % self.num_heads != 0:
+ raise ValueError("Output size must be divisible by the number of heads")
+
+ self.head_output_size = self.output_size // self.num_heads
+
+ if vision_tokens > 1:
+ self.resampler = Resampler(
+ dim=vision_hidden_size,
+ depth=4,
+ dim_head=64,
+ heads=12,
+ num_queries=num_heads, # output tokens
+ embedding_dim=vision_hidden_size,
+ max_seq_len=vision_tokens,
+ output_dim=head_dim,
+ apply_pos_emb=True, # this is new
+ ff_mult=4
+ )
+
+ self.proj_module = LoRAGenerator(
+ input_size=head_dim,
+ hidden_size=head_dim,
+ head_size=head_dim,
+ num_mlp_layers=1,
+ num_heads=self.num_heads,
+ output_size=self.output_size,
+ )
+
+ self.migrate_weight_mapping()
+
+ def migrate_weight_mapping(self):
+ return
+ # # changes the names of the modules to common ones
+ # keymap = self.sd_ref().network.get_keymap()
+ # save_keymap = {}
+ # if keymap is not None:
+ # for ldm_key, diffusers_key in keymap.items():
+ # # invert them
+ # save_keymap[diffusers_key] = ldm_key
+ #
+ # new_keymap = {}
+ # for key, value in self.weight_mapping:
+ # if key in save_keymap:
+ # new_keymap[save_keymap[key]] = value
+ # else:
+ # print(f"Key {key} not found in keymap")
+ # new_keymap[key] = value
+ # self.weight_mapping = new_keymap
+ # else:
+ # print("No keymap found. Using default names")
+ # return
+
+
+ def forward(self, img_embeds):
+ # expand token rank if only rank 2
+ if len(img_embeds.shape) == 2:
+ img_embeds = img_embeds.unsqueeze(1)
+
+ # resample the image embeddings
+ img_embeds = self.resampler(img_embeds)
+ img_embeds = self.proj_module(img_embeds)
+ if len(img_embeds.shape) == 3:
+ # merge the heads
+ img_embeds = img_embeds.mean(dim=1)
+
+ self.img_embeds = []
+ # get all the slices
+ start = 0
+ for length in self.embed_lengths:
+ self.img_embeds.append(img_embeds[:, start:start+length])
+ start += length
+
+
+ def get_additional_save_metadata(self) -> Dict[str, Any]:
+ # save the weight mapping
+ return {
+ "weight_mapping": self.weight_mapping,
+ "num_heads": self.num_heads,
+ "vision_hidden_size": self.vision_hidden_size,
+ "head_dim": self.head_dim,
+ "vision_tokens": self.vision_tokens,
+ "output_size": self.output_size,
+ }
+
diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46bd0a6d51a0e15856217a92c9aa27cf304287d
--- /dev/null
+++ b/toolkit/models/ilora2.py
@@ -0,0 +1,419 @@
+import math
+import weakref
+
+from toolkit.config_modules import AdapterConfig
+import torch
+import torch.nn as nn
+from typing import TYPE_CHECKING, List, Dict, Any
+from toolkit.models.clip_fusion import ZipperBlock
+from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
+import sys
+from toolkit.paths import REPOS_ROOT
+
+sys.path.append(REPOS_ROOT)
+from ipadapter.ip_adapter.resampler import Resampler
+from collections import OrderedDict
+
+if TYPE_CHECKING:
+ from toolkit.lora_special import LoRAModule
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+class MLP(nn.Module):
+ def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
+ super().__init__()
+ if use_residual:
+ assert in_dim == out_dim
+ self.layernorm = nn.LayerNorm(in_dim)
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.use_residual = use_residual
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ residual = x
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ x = self.dropout(x)
+ if self.use_residual:
+ x = x + residual
+ return x
+
+
+class LoRAGenerator(torch.nn.Module):
+ def __init__(
+ self,
+ input_size: int = 768, # projection dimension
+ hidden_size: int = 768,
+ head_size: int = 512,
+ num_heads: int = 1,
+ num_mlp_layers: int = 1,
+ output_size: int = 768,
+ dropout: float = 0.0
+ ):
+ super().__init__()
+ self.input_size = input_size
+ self.num_heads = num_heads
+ self.simple = False
+
+ self.output_size = output_size
+
+ if self.simple:
+ self.head = nn.Linear(input_size, head_size, bias=False)
+ else:
+ self.lin_in = nn.Linear(input_size, hidden_size)
+
+ self.mlp_blocks = nn.Sequential(*[
+ MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in
+ range(num_mlp_layers)
+ ])
+ self.head = nn.Linear(hidden_size, head_size, bias=False)
+ self.norm = nn.LayerNorm(head_size)
+
+ if num_heads == 1:
+ self.output = nn.Linear(head_size, self.output_size)
+ # for each output block. multiply weights by 0.01
+ with torch.no_grad():
+ self.output.weight.data *= 0.01
+ else:
+ head_output_size = output_size // num_heads
+ self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)])
+ # for each output block. multiply weights by 0.01
+ with torch.no_grad():
+ for output in self.outputs:
+ output.weight.data *= 0.01
+
+ # allow get device
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def forward(self, embedding):
+ if len(embedding.shape) == 2:
+ embedding = embedding.unsqueeze(1)
+
+ x = embedding
+
+ if not self.simple:
+ x = self.lin_in(embedding)
+ x = self.mlp_blocks(x)
+ x = self.head(x)
+ x = self.norm(x)
+
+ if self.num_heads == 1:
+ x = self.output(x)
+ else:
+ out_chunks = torch.chunk(x, self.num_heads, dim=1)
+ x = []
+ for out_layer, chunk in zip(self.outputs, out_chunks):
+ x.append(out_layer(chunk))
+ x = torch.cat(x, dim=-1)
+
+ return x.squeeze(1)
+
+
+class InstantLoRAMidModule(torch.nn.Module):
+ def __init__(
+ self,
+ index: int,
+ lora_module: 'LoRAModule',
+ instant_lora_module: 'InstantLoRAModule',
+ up_shape: list = None,
+ down_shape: list = None,
+ ):
+ super(InstantLoRAMidModule, self).__init__()
+ self.up_shape = up_shape
+ self.down_shape = down_shape
+ self.index = index
+ self.lora_module_ref = weakref.ref(lora_module)
+ self.instant_lora_module_ref = weakref.ref(instant_lora_module)
+
+ self.do_up = instant_lora_module.config.ilora_up
+ self.do_down = instant_lora_module.config.ilora_down
+ self.do_mid = instant_lora_module.config.ilora_mid
+
+ self.down_dim = self.down_shape[1] if self.do_down else 0
+ self.mid_dim = self.up_shape[1] if self.do_mid else 0
+ self.out_dim = self.up_shape[0] if self.do_up else 0
+
+ self.embed = None
+
+ def down_forward(self, x, *args, **kwargs):
+ if not self.do_down:
+ return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
+ # get the embed
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ down_weight = self.embed[:, :self.down_dim]
+
+ batch_size = x.shape[0]
+
+ # unconditional
+ if down_weight.shape[0] * 2 == batch_size:
+ down_weight = torch.cat([down_weight] * 2, dim=0)
+
+ try:
+ if len(x.shape) == 4:
+ # conv
+ down_weight = down_weight.view(batch_size, -1, 1, 1)
+ if x.shape[1] != down_weight.shape[1]:
+ raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
+ elif len(x.shape) == 2:
+ down_weight = down_weight.view(batch_size, -1)
+ if x.shape[1] != down_weight.shape[1]:
+ raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
+ else:
+ down_weight = down_weight.view(batch_size, 1, -1)
+ if x.shape[2] != down_weight.shape[2]:
+ raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
+ x = x * down_weight
+ x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
+ except Exception as e:
+ print(e)
+ raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}")
+
+ return x
+
+ def up_forward(self, x, *args, **kwargs):
+ # do mid here
+ x = self.mid_forward(x, *args, **kwargs)
+ if not self.do_up:
+ return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
+ # get the embed
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ up_weight = self.embed[:, -self.out_dim:]
+
+ batch_size = x.shape[0]
+
+ # unconditional
+ if up_weight.shape[0] * 2 == batch_size:
+ up_weight = torch.cat([up_weight] * 2, dim=0)
+
+ try:
+ if len(x.shape) == 4:
+ # conv
+ up_weight = up_weight.view(batch_size, -1, 1, 1)
+ elif len(x.shape) == 2:
+ up_weight = up_weight.view(batch_size, -1)
+ else:
+ up_weight = up_weight.view(batch_size, 1, -1)
+ x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
+ x = x * up_weight
+ except Exception as e:
+ print(e)
+ raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
+
+ return x
+
+ def mid_forward(self, x, *args, **kwargs):
+ if not self.do_mid:
+            # lora_down has already been applied in down_forward, so pass through unchanged
+            return x
+ batch_size = x.shape[0]
+ # get the embed
+ self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+ mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim]
+
+ # unconditional
+ if mid_weight.shape[0] * 2 == batch_size:
+ mid_weight = torch.cat([mid_weight] * 2, dim=0)
+
+ weight_chunks = torch.chunk(mid_weight, batch_size, dim=0)
+ x_chunks = torch.chunk(x, batch_size, dim=0)
+
+ x_out = []
+ for i in range(batch_size):
+ weight_chunk = weight_chunks[i]
+ x_chunk = x_chunks[i]
+ # reshape
+ if len(x_chunk.shape) == 4:
+ # conv
+ weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1)
+ else:
+ weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim)
+ # check if is conv or linear
+ if len(weight_chunk.shape) == 4:
+ padding = 0
+ if weight_chunk.shape[-1] == 3:
+ padding = 1
+ x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
+ else:
+                # run a simple linear layer with the mid weight
+ x_chunk = x_chunk @ weight_chunk.T
+ x_out.append(x_chunk)
+ x = torch.cat(x_out, dim=0)
+ return x
+
+
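+# Variant of ilora.py: instead of full LoRA weight matrices, the predicted vector supplies a
+# per-input-channel scale (ilora_down), a rank x rank mid matrix (ilora_mid), and a
+# per-output-channel scale (ilora_up), applied around the original lora_down/lora_up forwards.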
+class InstantLoRAModule(torch.nn.Module):
+ def __init__(
+ self,
+ vision_hidden_size: int,
+ vision_tokens: int,
+ head_dim: int,
+ num_heads: int, # number of heads in the resampler
+ sd: 'StableDiffusion',
+ config: AdapterConfig
+ ):
+ super(InstantLoRAModule, self).__init__()
+ # self.linear = torch.nn.Linear(2, 1)
+ self.sd_ref = weakref.ref(sd)
+ self.dim = sd.network.lora_dim
+ self.vision_hidden_size = vision_hidden_size
+ self.vision_tokens = vision_tokens
+ self.head_dim = head_dim
+ self.num_heads = num_heads
+
+ self.config: AdapterConfig = config
+
+ # stores the projection vector. Grabbed by modules
+ self.img_embeds: List[torch.Tensor] = None
+
+ # disable merging in. It is slower on inference
+ self.sd_ref().network.can_merge_in = False
+
+ self.ilora_modules = torch.nn.ModuleList()
+
+ lora_modules = self.sd_ref().network.get_all_modules()
+
+ output_size = 0
+
+ self.embed_lengths = []
+ self.weight_mapping = []
+
+ for idx, lora_module in enumerate(lora_modules):
+ module_dict = lora_module.state_dict()
+ down_shape = list(module_dict['lora_down.weight'].shape)
+ up_shape = list(module_dict['lora_up.weight'].shape)
+
+ self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
+
+ #
+ # module_size = math.prod(down_shape) + math.prod(up_shape)
+
+ # conv weight shape is (out_channels, in_channels, kernel_size, kernel_size)
+ # linear weight shape is (out_features, in_features)
+
+ # just doing in dim and out dim
+ in_dim = down_shape[1] if self.config.ilora_down else 0
+ mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0
+ out_dim = up_shape[0] if self.config.ilora_up else 0
+ module_size = in_dim + mid_dim + out_dim
+
+ output_size += module_size
+ self.embed_lengths.append(module_size)
+
+ # wrap the original LoRA forwards with a mid module that injects the
+ # per-image projection vector into the down/mid/up passes
+ instant_module = InstantLoRAMidModule(
+ idx,
+ lora_module,
+ self,
+ up_shape=up_shape,
+ down_shape=down_shape
+ )
+
+ self.ilora_modules.append(instant_module)
+
+ # replace the LoRA forwards
+ lora_module.lora_down.orig_forward = lora_module.lora_down.forward
+ lora_module.lora_down.forward = instant_module.down_forward
+ lora_module.lora_up.orig_forward = lora_module.lora_up.forward
+ lora_module.lora_up.forward = instant_module.up_forward
+
+ self.output_size = output_size
+
+ number_formatted_output_size = "{:,}".format(output_size)
+
+ print(f" ILORA output size: {number_formatted_output_size}")
+
+ # if not evenly divisible, error
+ if self.output_size % self.num_heads != 0:
+ raise ValueError("Output size must be divisible by the number of heads")
+
+ self.head_output_size = self.output_size // self.num_heads
+
+ if vision_tokens > 1:
+ self.resampler = Resampler(
+ dim=vision_hidden_size,
+ depth=4,
+ dim_head=64,
+ heads=12,
+ num_queries=num_heads, # output tokens
+ embedding_dim=vision_hidden_size,
+ max_seq_len=vision_tokens,
+ output_dim=head_dim,
+ apply_pos_emb=True, # this is new
+ ff_mult=4
+ )
+
+ self.proj_module = LoRAGenerator(
+ input_size=head_dim,
+ hidden_size=head_dim,
+ head_size=head_dim,
+ num_mlp_layers=1,
+ num_heads=self.num_heads,
+ output_size=self.output_size,
+ )
+
+ self.migrate_weight_mapping()
+
+ def migrate_weight_mapping(self):
+ return
+ # # changes the names of the modules to common ones
+ # keymap = self.sd_ref().network.get_keymap()
+ # save_keymap = {}
+ # if keymap is not None:
+ # for ldm_key, diffusers_key in keymap.items():
+ # # invert them
+ # save_keymap[diffusers_key] = ldm_key
+ #
+ # new_keymap = {}
+ # for key, value in self.weight_mapping:
+ # if key in save_keymap:
+ # new_keymap[save_keymap[key]] = value
+ # else:
+ # print(f"Key {key} not found in keymap")
+ # new_keymap[key] = value
+ # self.weight_mapping = new_keymap
+ # else:
+ # print("No keymap found. Using default names")
+ # return
+
+ def forward(self, img_embeds):
+ # expand token rank if only rank 2
+ if len(img_embeds.shape) == 2:
+ img_embeds = img_embeds.unsqueeze(1)
+
+ # resample the image embeddings
+ img_embeds = self.resampler(img_embeds)
+ img_embeds = self.proj_module(img_embeds)
+ if len(img_embeds.shape) == 3:
+ # merge the heads
+ img_embeds = img_embeds.mean(dim=1)
+
+ self.img_embeds = []
+ # get all the slices
+ start = 0
+ for length in self.embed_lengths:
+ self.img_embeds.append(img_embeds[:, start:start + length])
+ start += length
+
+ def get_additional_save_metadata(self) -> Dict[str, Any]:
+ # save the weight mapping
+ return {
+ "weight_mapping": self.weight_mapping,
+ "num_heads": self.num_heads,
+ "vision_hidden_size": self.vision_hidden_size,
+ "head_dim": self.head_dim,
+ "vision_tokens": self.vision_tokens,
+ "output_size": self.output_size,
+ "do_up": self.config.ilora_up,
+ "do_mid": self.config.ilora_mid,
+ "do_down": self.config.ilora_down,
+ }
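+
+
+# Rough usage sketch (editorial addition, not part of the original patch). The sizes and the
+# vision-encoder wiring below are assumptions; shapes follow the module definitions above:
+#
+#   ilora = InstantLoRAModule(vision_hidden_size=1152, vision_tokens=729,
+#                             head_dim=1152, num_heads=1, sd=sd, config=adapter_config)
+#   clip_out = vision_encoder(pixel_values)   # (batch, vision_tokens, vision_hidden_size)
+#   ilora(clip_out)                           # resample -> project -> slice per LoRA module
+#   # each hooked lora_down/lora_up forward now reads its slice from ilora.img_embeds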
diff --git a/toolkit/models/pixtral_vision.py b/toolkit/models/pixtral_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..815f33101ffb4cfc672cefde6f94fd01ae198783
--- /dev/null
+++ b/toolkit/models/pixtral_vision.py
@@ -0,0 +1,618 @@
+import math
+from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING
+import os
+import torch
+import torch.nn as nn
+from dataclasses import dataclass
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+import json
+
+if TYPE_CHECKING:
+ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim: int, hidden_dim: int, **kwargs):
+ super().__init__()
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # type: ignore
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
+ values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
+ return keys, values
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis[:, None, :]
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ head_dim: int,
+ n_kv_heads: int,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.n_heads: int = n_heads
+ self.head_dim: int = head_dim
+ self.n_kv_heads: int = n_kv_heads
+
+ self.repeats = self.n_heads // self.n_kv_heads
+
+ self.scale = self.head_dim ** -0.5
+
+ self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
+ self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+ self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+ self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ cache: Optional[Any] = None,
+ mask: Optional['BlockDiagonalMask'] = None,
+ ) -> torch.Tensor:
+ from xformers.ops.fmha import memory_efficient_attention
+ assert mask is None or cache is None
+ seqlen_sum, _ = x.shape
+
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+ xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
+ xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
+ xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+ if cache is None:
+ key, val = xk, xv
+ elif cache.prefill:
+ key, val = cache.interleave_kv(xk, xv)
+ cache.update(xk, xv)
+ else:
+ cache.update(xk, xv)
+ key, val = cache.key, cache.value
+ key = key.view(seqlen_sum * cache.max_seq_len,
+ self.n_kv_heads, self.head_dim)
+ val = val.view(seqlen_sum * cache.max_seq_len,
+ self.n_kv_heads, self.head_dim)
+
+ # Repeat keys and values to match number of query heads
+ key, val = repeat_kv(key, val, self.repeats, dim=1)
+
+ # xformers requires (B=1, S, H, D)
+ xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
+ output = memory_efficient_attention(
+ xq, key, val, mask if cache is None else cache.mask)
+ output = output.view(seqlen_sum, self.n_heads * self.head_dim)
+
+ assert isinstance(output, torch.Tensor)
+
+ return self.wo(output) # type: ignore
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ head_dim: int,
+ norm_eps: float,
+ **kwargs,
+ ):
+ super().__init__()
+ self.n_heads = n_heads
+ self.dim = dim
+ self.attention = Attention(
+ dim=dim,
+ n_heads=n_heads,
+ head_dim=head_dim,
+ n_kv_heads=n_kv_heads,
+ )
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+
+ self.feed_forward: nn.Module
+ self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ cache: Optional[Any] = None,
+ mask: Optional['BlockDiagonalMask'] = None,
+ ) -> torch.Tensor:
+ r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
+ h = x + r
+ r = self.feed_forward.forward(self.ffn_norm(h))
+ out = h + r
+ return out
+
+
+@dataclass
+class VisionEncoderArgs:
+ hidden_size: int
+ num_channels: int
+ image_size: int
+ patch_size: int
+ intermediate_size: int
+ num_hidden_layers: int
+ num_attention_heads: int
+ rope_theta: float = 1e4 # for rope-2D
+ image_token_id: int = 10
+
+
+def precompute_freqs_cis_2d(
+ dim: int,
+ height: int,
+ width: int,
+ theta: float,
+) -> torch.Tensor:
+ """
+ freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
+ (height, width) position tuples
+ """
+ # (dim / 2) frequency bases
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+
+ h = torch.arange(height, device=freqs.device)
+ w = torch.arange(width, device=freqs.device)
+
+ freqs_h = torch.outer(h, freqs[::2]).float()
+ freqs_w = torch.outer(w, freqs[1::2]).float()
+ freqs_2d = torch.cat(
+ [
+ freqs_h[:, None, :].repeat(1, width, 1),
+ freqs_w[None, :, :].repeat(height, 1, 1),
+ ],
+ dim=-1,
+ )
+ return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
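+
+
+# Editorial note (not part of the original patch): a quick shape check of the 2D RoPE table.
+# With hidden_size=1024 and 16 attention heads, dim = 64, and a 1024px image with 16px patches
+# gives height = width = 64, so:
+#
+#   freqs = precompute_freqs_cis_2d(dim=64, height=64, width=64, theta=1e4)
+#   assert freqs.shape == (64, 64, 32) and freqs.is_complex()
+#
+# Each patch's (row, col) position is later looked up via position_meshgrid().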
+
+
+def position_meshgrid(
+ patch_embeds_list: list[torch.Tensor],
+) -> torch.Tensor:
+ positions = torch.cat(
+ [
+ torch.stack(
+ torch.meshgrid(
+ torch.arange(p.shape[-2]),
+ torch.arange(p.shape[-1]),
+ indexing="ij",
+ ),
+ dim=-1,
+ ).reshape(-1, 2)
+ for p in patch_embeds_list
+ ]
+ )
+ return positions
+
+
+class PixtralVisionEncoder(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 1024,
+ num_channels: int = 3,
+ image_size: int = 1024,
+ patch_size: int = 16,
+ intermediate_size: int = 4096,
+ num_hidden_layers: int = 24,
+ num_attention_heads: int = 16,
+ rope_theta: float = 1e4, # for rope-2D
+ image_token_id: int = 10,
+ **kwargs,
+ ):
+ super().__init__()
+ self.args = VisionEncoderArgs(
+ hidden_size=hidden_size,
+ num_channels=num_channels,
+ image_size=image_size,
+ patch_size=patch_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ rope_theta=rope_theta,
+ image_token_id=image_token_id,
+ )
+ args = self.args
+ self.patch_conv = nn.Conv2d(
+ in_channels=args.num_channels,
+ out_channels=args.hidden_size,
+ kernel_size=args.patch_size,
+ stride=args.patch_size,
+ bias=False,
+ )
+ self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
+ self.transformer = VisionTransformerBlocks(args)
+
+ head_dim = self.args.hidden_size // self.args.num_attention_heads
+ assert head_dim % 2 == 0, "ROPE requires even head_dim"
+ self._freqs_cis: Optional[torch.Tensor] = None
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
+ if os.path.isdir(pretrained_model_name_or_path):
+ model_folder = pretrained_model_name_or_path
+ else:
+ model_folder = snapshot_download(pretrained_model_name_or_path)
+
+ # make sure there is a config
+ if not os.path.exists(os.path.join(model_folder, "config.json")):
+ raise ValueError(f"Could not find config.json in {model_folder}")
+
+ # load config
+ with open(os.path.join(model_folder, "config.json"), "r") as f:
+ config = json.load(f)
+
+ model = cls(**config)
+
+ # see if there is a state_dict
+ if os.path.exists(os.path.join(model_folder, "model.safetensors")):
+ state_dict = load_file(os.path.join(
+ model_folder, "model.safetensors"))
+ model.load_state_dict(state_dict)
+
+ return model
+
+ @property
+ def max_patches_per_side(self) -> int:
+ return self.args.image_size // self.args.patch_size
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def freqs_cis(self) -> torch.Tensor:
+ if self._freqs_cis is None:
+ self._freqs_cis = precompute_freqs_cis_2d(
+ dim=self.args.hidden_size // self.args.num_attention_heads,
+ height=self.max_patches_per_side,
+ width=self.max_patches_per_side,
+ theta=self.args.rope_theta,
+ )
+
+ if self._freqs_cis.device != self.device:
+ self._freqs_cis = self._freqs_cis.to(device=self.device)
+
+ return self._freqs_cis
+
+ def forward(
+ self,
+ images: List[torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ Args:
+ images: list of N_img images of variable sizes, each of shape (C, H, W)
+
+ Returns:
+ image_features: tensor of token features for all tokens of all images of
+ shape (N_toks, D)
+ """
+ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+ assert isinstance(
+ images, list), f"Expected list of images, got {type(images)}"
+ assert all(len(img.shape) == 3 for img in
+ images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
+ # pass images through initial convolution independently
+ patch_embeds_list = [self.patch_conv(
+ img.unsqueeze(0)).squeeze(0) for img in images]
+
+ # flatten to a single sequence
+ patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
+ for p in patch_embeds_list], dim=0)
+ patch_embeds = self.ln_pre(patch_embeds)
+
+ # positional embeddings
+ positions = position_meshgrid(patch_embeds_list).to(self.device)
+ freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
+
+ # pass through Transformer with a block diagonal mask delimiting images
+ mask = BlockDiagonalMask.from_seqlens(
+ [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+ )
+ out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
+
+ # (N_toks, D) features for the single flattened sequence of all images
+ return out # type: ignore[no-any-return]
+
+
+class VisionLanguageAdapter(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int):
+ super().__init__()
+ self.w_in = nn.Linear(
+ in_dim,
+ out_dim,
+ bias=True,
+ )
+ self.gelu = nn.GELU()
+ self.w_out = nn.Linear(out_dim, out_dim, bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # type: ignore[no-any-return]
+ return self.w_out(self.gelu(self.w_in(x)))
+
+
+class VisionTransformerBlocks(nn.Module):
+ def __init__(self, args: VisionEncoderArgs):
+ super().__init__()
+ self.layers = torch.nn.ModuleList()
+ for _ in range(args.num_hidden_layers):
+ self.layers.append(
+ TransformerBlock(
+ dim=args.hidden_size,
+ hidden_dim=args.intermediate_size,
+ n_heads=args.num_attention_heads,
+ n_kv_heads=args.num_attention_heads,
+ head_dim=args.hidden_size // args.num_attention_heads,
+ norm_eps=1e-5,
+ )
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: 'BlockDiagonalMask',
+ freqs_cis: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+ for layer in self.layers:
+ x = layer(x, mask=mask, freqs_cis=freqs_cis)
+ return x
+
+
+DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073] # RGB
+DATASET_STD = [0.26862954, 0.26130258, 0.27577711] # RGB
+
+
+def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize a tensor image with mean and standard deviation.
+
+ Args:
+ image (torch.Tensor): Image to be normalized, shape (C, H, W), values in [0, 1].
+ mean (torch.Tensor): Mean for each channel.
+ std (torch.Tensor): Standard deviation for each channel.
+
+ Returns:
+ torch.Tensor: Normalized image with shape (C, H, W).
+ """
+ assert image.shape[0] == len(mean) == len(
+ std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
+
+ # Reshape mean and std to (C, 1, 1) for broadcasting
+ mean = mean.view(-1, 1, 1)
+ std = std.view(-1, 1, 1)
+
+ return (image - mean) / std
+
+
+def transform_image(image: torch.Tensor, new_size: tuple[int, int]) -> torch.Tensor:
+ """
+ Resize and normalize the input image.
+
+ Args:
+ image (torch.Tensor): Input image tensor of shape (C, H, W), values in [0, 1].
+ new_size (tuple[int, int]): Target size (height, width) for resizing.
+
+ Returns:
+ torch.Tensor: Resized and normalized image tensor of shape (C, new_H, new_W).
+ """
+ # Resize the image
+ resized_image = torch.nn.functional.interpolate(
+ image.unsqueeze(0),
+ size=new_size,
+ mode='bicubic',
+ align_corners=False
+ ).squeeze(0)
+
+ # Normalize the image
+ normalized_image = normalize(
+ resized_image,
+ torch.tensor(DATASET_MEAN, device=image.device, dtype=image.dtype),
+ torch.tensor(DATASET_STD, device=image.device, dtype=image.dtype)
+ )
+
+ return normalized_image
+
+
+class PixtralVisionImagePreprocessor:
+ def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
+ self.image_patch_size = image_patch_size
+ self.max_image_size = max_image_size
+ self.image_token = 10
+
+ def _image_to_num_tokens(self, img: torch.Tensor, max_image_size = None) -> Tuple[int, int]:
+ w: Union[int, float]
+ h: Union[int, float]
+
+ if max_image_size is None:
+ max_image_size = self.max_image_size
+
+ w, h = img.shape[-1], img.shape[-2]
+
+ # originally, pixtral used the largest of the 2 dimensions, but we
+ # will use the base size of the image based on number of pixels.
+ # ratio = max(h / self.max_image_size, w / self.max_image_size) # original
+
+ base_size = int(math.sqrt(w * h))
+ ratio = base_size / max_image_size
+ if ratio > 1:
+ w = round(w / ratio)
+ h = round(h / ratio)
+
+ width_tokens = (w - 1) // self.image_patch_size + 1
+ height_tokens = (h - 1) // self.image_patch_size + 1
+
+ return width_tokens, height_tokens
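+
+ # Worked example (editorial addition, not in the original patch), using the default
+ # max_image_size of 1024 and a 1536x2048 (H x W) input:
+ #   base_size = sqrt(2048 * 1536) ~= 1774, ratio ~= 1.73
+ #   w ~= round(2048 / 1.73) = 1182, h ~= round(1536 / 1.73) = 887
+ #   width_tokens = (1182 - 1) // 16 + 1 = 74, height_tokens = (887 - 1) // 16 + 1 = 56
+ # so the image is later resized to a 74x56 grid of 16px patches.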
+
+ def __call__(self, image: torch.Tensor, max_image_size=None) -> torch.Tensor:
+ """
+ Resizes and normalizes a single image tensor to a patch-aligned size.
+
+ Args:
+ image: torch tensor with values 0-1 and shape of (C, H, W)
+
+ Returns:
+ processed_image: resized and normalized tensor of shape (C, new_H, new_W)
+ """
+ # should not have batch
+ if len(image.shape) == 4:
+ raise ValueError(
+ f"Expected image with shape (C, H, W), got {image.shape}")
+
+ if image.min() < 0.0 or image.max() > 1.0:
+ raise ValueError(
+ f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
+
+ if max_image_size is None:
+ max_image_size = self.max_image_size
+
+ w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
+ assert w > 0
+ assert h > 0
+
+ # transform_image expects (height, width); w and h here are token counts
+ new_image_size = (
+ h * self.image_patch_size,
+ w * self.image_patch_size,
+ )
+
+ processed_image = transform_image(image, new_image_size)
+
+ return processed_image
+
+
+class PixtralVisionImagePreprocessorCompatibleReturn:
+ def __init__(self, pixel_values) -> None:
+ self.pixel_values = pixel_values
+
+
+# Compatible version for the ai-toolkit flow
+class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
+ def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
+ super().__init__(
+ image_patch_size=image_patch_size,
+ max_image_size=max_image_size
+ )
+ self.size = {
+ 'height': max_image_size,
+ 'width': max_image_size
+ }
+ self.max_image_size = max_image_size
+ self.image_mean = DATASET_MEAN
+ self.image_std = DATASET_STD
+
+ def __call__(
+ self,
+ images,
+ return_tensors="pt",
+ do_resize=True,
+ do_rescale=False,
+ max_image_size=None,
+ ) -> PixtralVisionImagePreprocessorCompatibleReturn:
+ if max_image_size is None:
+ max_image_size = self.max_image_size
+ out_stack = []
+ if len(images.shape) == 3:
+ images = images.unsqueeze(0)
+ for i in range(images.shape[0]):
+ image = images[i]
+ processed_image = super().__call__(image, max_image_size=max_image_size)
+ out_stack.append(processed_image)
+
+ output = torch.stack(out_stack, dim=0)
+ return PixtralVisionImagePreprocessorCompatibleReturn(output)
+
+
+class PixtralVisionEncoderCompatibleReturn:
+ def __init__(self, hidden_states) -> None:
+ self.hidden_states = hidden_states
+
+
+class PixtralVisionEncoderCompatibleConfig:
+ def __init__(self):
+ self.image_size = 1024
+ self.hidden_size = 1024
+ self.patch_size = 16
+
+
+class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
+ def __init__(
+ self,
+ hidden_size: int = 1024,
+ num_channels: int = 3,
+ image_size: int = 1024,
+ patch_size: int = 16,
+ intermediate_size: int = 4096,
+ num_hidden_layers: int = 24,
+ num_attention_heads: int = 16,
+ rope_theta: float = 1e4, # for rope-2D
+ image_token_id: int = 10,
+ **kwargs,
+ ):
+ super().__init__(
+ hidden_size=hidden_size,
+ num_channels=num_channels,
+ image_size=image_size,
+ patch_size=patch_size,
+ intermediate_size=intermediate_size,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ rope_theta=rope_theta,
+ image_token_id=image_token_id,
+ )
+ self.config = PixtralVisionEncoderCompatibleConfig()
+
+ def forward(
+ self,
+ images,
+ output_hidden_states=True,
+ ) -> PixtralVisionEncoderCompatibleReturn:
+ out_stack = []
+ if len(images.shape) == 3:
+ images = images.unsqueeze(0)
+ for i in range(images.shape[0]):
+ image = images[i]
+ # must be in an array
+ image_output = super().forward([image])
+ out_stack.append(image_output)
+
+ output = torch.stack(out_stack, dim=0)
+ return PixtralVisionEncoderCompatibleReturn([output])
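+
+
+# Rough usage sketch (editorial addition, not part of the original patch). The weight path is
+# an assumption; any folder containing config.json and model.safetensors in this layout works:
+#
+#   preprocessor = PixtralVisionImagePreprocessor(image_patch_size=16, max_image_size=1024)
+#   encoder = PixtralVisionEncoder.from_pretrained("path/to/pixtral-vision-weights")
+#   image = torch.rand(3, 768, 1024)            # (C, H, W), values in [0, 1]
+#   tokens = encoder([preprocessor(image)])     # (N_toks, hidden_size)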
diff --git a/toolkit/models/redux.py b/toolkit/models/redux.py
new file mode 100644
index 0000000000000000000000000000000000000000..609ac50ae7f1404cfd85c63532339fcf94ae60c3
--- /dev/null
+++ b/toolkit/models/redux.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+
+class ReduxImageEncoder(torch.nn.Module):
+ def __init__(
+ self,
+ redux_dim: int = 1152,
+ txt_in_features: int = 4096,
+ device=None,
+ dtype=None,
+ ) -> None:
+ super().__init__()
+ self.redux_dim = redux_dim
+ self.device = device
+ self.dtype = dtype
+ self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+ self.redux_down = nn.Linear(
+ txt_in_features * 3, txt_in_features, dtype=dtype)
+
+ def forward(self, sigclip_embeds) -> torch.Tensor:
+ x = self.redux_up(sigclip_embeds)
+ x = torch.nn.functional.silu(x)
+
+ projected_x = self.redux_down(x)
+ return projected_x
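+
+
+# Rough usage sketch (editorial addition, not part of the original patch). The 1152 input width
+# matches SigLIP image embeddings and 4096 matches the T5 text width used by FLUX; the batch and
+# sequence sizes below are illustrative:
+#
+#   redux = ReduxImageEncoder(redux_dim=1152, txt_in_features=4096)
+#   sigclip_embeds = torch.randn(1, 729, 1152)
+#   txt_like = redux(sigclip_embeds)            # (1, 729, 4096), ready to concat with T5 tokens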
diff --git a/toolkit/models/single_value_adapter.py b/toolkit/models/single_value_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..9284d02093c0be9fd16d5effb794a590d8449a4c
--- /dev/null
+++ b/toolkit/models/single_value_adapter.py
@@ -0,0 +1,402 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import Union, TYPE_CHECKING
+
+from diffusers import Transformer2DModel
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+ from toolkit.custom_adapter import CustomAdapter
+
+class AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+class SingleValueAdapterAttnProcessor(nn.Module):
+ r"""
+ Attention processor for the single-value adapter, using PyTorch 2.0 scaled dot-product attention.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ the weight scale of the adapter conditioning.
+ adapter (`SingleValueAdapter`):
+ The parent adapter, held by weak reference.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
+ adapter_hidden_size=None, has_bias=False, **kwargs):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+ self.hidden_size = hidden_size
+ self.adapter_hidden_size = adapter_hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+
+ self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+ self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+ # return False
+
+ @property
+ def unconditional_embeds(self):
+ return self.adapter_ref().adapter_ref().unconditional_embeds
+
+ @property
+ def conditional_embeds(self):
+ return self.adapter_ref().adapter_ref().conditional_embeds
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ is_active = self.adapter_ref().is_active
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ # will be none if disabled
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # when the adapter is active, add its conditioning as extra key/value attention on top of the base result
+ if self.is_active and self.conditional_embeds is not None:
+
+ adapter_hidden_states = self.conditional_embeds
+ if adapter_hidden_states.shape[0] < batch_size:
+ # doing cfg
+ adapter_hidden_states = torch.cat([
+ self.unconditional_embeds,
+ adapter_hidden_states
+ ], dim=0)
+ # needs to be shape (batch, 1, 1)
+ if len(adapter_hidden_states.shape) == 2:
+ adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+ # conditional_batch_size = adapter_hidden_states.shape[0]
+ # conditional_query = query
+
+ # for ip-adapter
+ vd_key = self.to_k_adapter(adapter_hidden_states)
+ vd_value = self.to_v_adapter(adapter_hidden_states)
+
+ vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ vd_hidden_states = F.scaled_dot_product_attention(
+ query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ vd_hidden_states = vd_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + self.scale * vd_hidden_states
+
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SingleValueAdapter(torch.nn.Module):
+ def __init__(
+ self,
+ adapter: 'CustomAdapter',
+ sd: 'StableDiffusion',
+ num_values: int = 1,
+ ):
+ super(SingleValueAdapter, self).__init__()
+ is_pixart = sd.is_pixart
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ self.token_size = num_values
+
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = sd.unet.state_dict()
+
+ attn_processor_keys = []
+ if is_pixart:
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+ # cross attention
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+
+ else:
+ attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+ for name in attn_processor_keys:
+ cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
+ if name.startswith("mid_block"):
+ hidden_size = sd.unet.config['block_out_channels'][-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = sd.unet.config['block_out_channels'][block_id]
+ elif name.startswith("transformer"):
+ hidden_size = sd.unet.config['cross_attention_dim']
+ else:
+ # not in the upstream implementation, but hidden_size would be undefined below without it
+ raise ValueError(f"unknown attn processor name: {name}")
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor2_0()
+ else:
+ layer_name = name.split(".processor")[0]
+ to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+ to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
+ # if is_pixart:
+ # to_k_bias = unet_sd[layer_name + ".to_k.bias"]
+ # to_v_bias = unet_sd[layer_name + ".to_v.bias"]
+ # else:
+ # to_k_bias = None
+ # to_v_bias = None
+
+ # add zero padding to the adapter
+ if to_k_adapter.shape[1] < self.token_size:
+ to_k_adapter = torch.cat([
+ to_k_adapter,
+ torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to(
+ to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
+ ],
+ dim=1
+ )
+ to_v_adapter = torch.cat([
+ to_v_adapter,
+ torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to(
+ to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
+ ],
+ dim=1
+ )
+ # if is_pixart:
+ # to_k_bias = torch.cat([
+ # to_k_bias,
+ # torch.zeros(self.token_size - to_k_adapter.shape[1]).to(
+ # to_k_adapter.device, dtype=to_k_adapter.dtype)
+ # ],
+ # dim=0
+ # )
+ # to_v_bias = torch.cat([
+ # to_v_bias,
+ # torch.zeros(self.token_size - to_v_adapter.shape[1]).to(
+ # to_k_adapter.device, dtype=to_k_adapter.dtype)
+ # ],
+ # dim=0
+ # )
+ elif to_k_adapter.shape[1] > self.token_size:
+ to_k_adapter = to_k_adapter[:, :self.token_size]
+ to_v_adapter = to_v_adapter[:, :self.token_size]
+ # if is_pixart:
+ # to_k_bias = to_k_bias[:self.token_size]
+ # to_v_bias = to_v_bias[:self.token_size]
+ else:
+ to_k_adapter = to_k_adapter
+ to_v_adapter = to_v_adapter
+ # if is_pixart:
+ # to_k_bias = to_k_bias
+ # to_v_bias = to_v_bias
+
+ weights = {
+ "to_k_adapter.weight": to_k_adapter * 0.01,
+ "to_v_adapter.weight": to_v_adapter * 0.01,
+ }
+ # if is_pixart:
+ # weights["to_k_adapter.bias"] = to_k_bias
+ # weights["to_v_adapter.bias"] = to_v_bias
+
+ attn_procs[name] = SingleValueAdapterAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ adapter=self,
+ adapter_hidden_size=self.token_size,
+ has_bias=False,
+ )
+ attn_procs[name].load_state_dict(weights)
+ if self.sd_ref().is_pixart:
+ # we have to set them ourselves
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+ module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+ self.adapter_modules = torch.nn.ModuleList([
+ transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks))
+ ] + [
+ transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks))
+ ])
+ else:
+ sd.unet.set_attn_processor(attn_procs)
+ self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+
+ # make a getter to see if is active
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+ def forward(self, input):
+ return input
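+
+
+# Conceptual sketch (editorial addition, not part of the original patch): the adapter feeds a
+# scalar (or small vector) through the 0.01-scaled to_k_adapter/to_v_adapter projections that
+# were initialized from the UNet's own to_k/to_v weights, attends over that single extra
+# key/value token, and adds the result scaled by `scale`. `custom_adapter` below stands for the
+# owning CustomAdapter instance and is an assumed name:
+#
+#   value = torch.tensor([[0.7]])                            # (batch, num_values), e.g. a strength knob
+#   custom_adapter.conditional_embeds = value                # read back inside SingleValueAdapterAttnProcessor
+#   custom_adapter.unconditional_embeds = torch.zeros_like(value)   # prepended when doing CFG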
diff --git a/toolkit/models/size_agnostic_feature_encoder.py b/toolkit/models/size_agnostic_feature_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a716aec504503afa2c103506876c91bc2b617d07
--- /dev/null
+++ b/toolkit/models/size_agnostic_feature_encoder.py
@@ -0,0 +1,256 @@
+import os
+from typing import Union, Optional
+
+import torch
+import torch.nn as nn
+from transformers.image_processing_utils import BaseImageProcessor
+
+
+class SAFEReducerBlock(nn.Module):
+ """
+ Reduces the width and height of a feature map by half. It is designed to be iterative,
+ so it is run multiple times to reduce an image to a desired dimension while carrying a
+ shrinking residual along for the ride, which helps preserve information.
+ """
+ def __init__(self, channels=512):
+ super(SAFEReducerBlock, self).__init__()
+ self.channels = channels
+
+ activation = nn.GELU
+
+ self.reducer = nn.Sequential(
+ nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+ activation(),
+ nn.BatchNorm2d(channels),
+ nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+ activation(),
+ nn.BatchNorm2d(channels),
+ nn.AvgPool2d(kernel_size=2, stride=2),
+ )
+ self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
+
+ def forward(self, x):
+ res = self.residual_shrink(x)
+ reduced = self.reducer(x)
+ return reduced + res
+
+
+class SizeAgnosticFeatureEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ num_tokens=8,
+ num_vectors=768,
+ reducer_channels=512,
+ channels=2048,
+ downscale_factor: int = 8,
+ ):
+ super(SizeAgnosticFeatureEncoder, self).__init__()
+ self.num_tokens = num_tokens
+ self.num_vectors = num_vectors
+ self.channels = channels
+ self.reducer_channels = reducer_channels
+ self.gradient_checkpointing = False
+
+ # input is minimum of (bs, 3, 256, 256)
+
+ subpixel_channels = in_channels * downscale_factor ** 2
+
+ # PixelUnshuffle(8): (bs, 3, 256, 256) -> (bs, 192, 32, 32)
+ # PixelUnshuffle(16): (bs, 3, 256, 256) -> (bs, 768, 16, 16)
+
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor) # (bs, 3, 256, 256) -> (bs, 192, 32, 32)
+
+ self.conv_in = nn.Conv2d(subpixel_channels, reducer_channels, kernel_size=3, padding=1) # (bs, 192, 32, 32) -> (bs, 512, 32, 32)
+
+ # run as many times as needed to get to min feature of 8 on the smallest dimension
+ self.reducer = SAFEReducerBlock(reducer_channels) # applied repeatedly: (bs, 512, 32, 32) -> (bs, 512, 8, 8)
+
+ self.reduced_out = nn.Conv2d(
+ reducer_channels, self.channels, kernel_size=3, padding=1
+ ) # (bs, 512, 8, 8) -> (bs, 2048, 8, 8)
+
+ # (bs, 2048, 8, 8)
+ self.block1 = SAFEReducerBlock(self.channels) # (bs, 2048, 8, 8) -> (bs, 2048, 4, 4)
+ self.block2 = SAFEReducerBlock(self.channels) # (bs, 2048, 4, 4) -> (bs, 2048, 2, 2)
+
+ # reduce mean of dims 2 and 3
+ self.adaptive_pool = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Flatten(),
+ )
+
+ # (bs, 2048)
+ # linear layer to (bs, self.num_vectors * self.num_tokens)
+ self.fc1 = nn.Linear(self.channels, self.num_vectors * self.num_tokens)
+
+ # (bs, self.num_vectors * self.num_tokens) = (bs, 8 * 768) = (bs, 6144)
+
+ def forward(self, x):
+ x = self.unshuffle(x)
+ x = self.conv_in(x)
+
+ while True:
+ # reduce until we get as close to 8x8 as possible without going under
+ x = self.reducer(x)
+ if x.shape[2] // 2 < 8 or x.shape[3] // 2 < 8:
+ break
+
+ x = self.reduced_out(x)
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.adaptive_pool(x)
+ x = self.fc1(x)
+
+ # reshape
+ x = x.view(-1, self.num_tokens, self.num_vectors)
+
+ return x
+
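+# Shape walk-through (editorial addition, not part of the original patch), for the default
+# arguments and a (bs, 3, 512, 768) input:
+#   unshuffle   -> (bs, 192, 64, 96)
+#   conv_in     -> (bs, 512, 64, 96)
+#   reducer x3  -> (bs, 512, 8, 12)    # stops once another halving would drop below 8
+#   reduced_out -> (bs, 2048, 8, 12)
+#   block1/2    -> (bs, 2048, 4, 6) -> (bs, 2048, 2, 3)
+#   pool + fc1  -> (bs, 2048) -> (bs, 6144) -> view -> (bs, 8, 768)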
+
+class SAFEIPReturn:
+ def __init__(self, pixel_values):
+ self.pixel_values = pixel_values
+
+
+class SAFEImageProcessor(BaseImageProcessor):
+ def __init__(
+ self,
+ max_size=1024,
+ min_size=256,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.max_size = max_size
+ self.min_size = min_size
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: Optional[Union[str, bool]] = None,
+ revision: str = "main",
+ **kwargs,
+ ):
+ # not needed
+ return cls(**kwargs)
+
+ def __call__(
+ self,
+ images,
+ **kwargs
+ ):
+ # TODO allow for random resizing
+ # comes in 0 - 1 range
+ # if any size is smaller than 256, resize to 256
+ # if any size is larger than max_size, resize to max_size
+ if images.min() < -0.3 or images.max() > 1.3:
+ raise ValueError(
+ "images fed into SAFEImageProcessor values must be between 0 and 1. Got min: {}, max: {}".format(
+ images.min(), images.max()
+ ))
+
+ # make sure we have (bs, 3, h, w)
+ while len(images.shape) < 4:
+ images = images.unsqueeze(0)
+
+ # expand to 3 channels if we only have 1 channel
+ if images.shape[1] == 1:
+ images = torch.cat([images, images, images], dim=1)
+
+ width = images.shape[3]
+ height = images.shape[2]
+
+ if width < self.min_size or height < self.min_size:
+ # scale up so that the smallest size is 256
+ if width < height:
+ new_width = self.min_size
+ new_height = int(height * (self.min_size / width))
+ else:
+ new_height = self.min_size
+ new_width = int(width * (self.min_size / height))
+ images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
+ align_corners=False)
+
+ elif width > self.max_size or height > self.max_size:
+ # scale down so that the largest size is max_size but do not shrink the other size below 256
+ if width > height:
+ new_width = self.max_size
+ new_height = int(height * (self.max_size / width))
+ else:
+ new_height = self.max_size
+ new_width = int(width * (self.max_size / height))
+
+ if new_width < self.min_size:
+ new_width = self.min_size
+ new_height = int(height * (self.min_size / width))
+
+ if new_height < self.min_size:
+ new_height = self.min_size
+ new_width = int(width * (self.min_size / height))
+
+ images = nn.functional.interpolate(images, size=(new_height, new_width), mode='bilinear',
+ align_corners=False)
+
+ # if either side is not divisible by 16, mirror pad to make it so
+ if images.shape[2] % 16 != 0:
+ pad = 16 - (images.shape[2] % 16)
+ pad1 = pad // 2
+ pad2 = pad - pad1
+ images = nn.functional.pad(images, (0, 0, pad1, pad2), mode='reflect')
+ if images.shape[3] % 16 != 0:
+ pad = 16 - (images.shape[3] % 16)
+ pad1 = pad // 2
+ pad2 = pad - pad1
+ images = nn.functional.pad(images, (pad1, pad2, 0, 0), mode='reflect')
+
+ return SAFEIPReturn(images)
+
+
+class SAFEVMConfig:
+ def __init__(
+ self,
+ in_channels=3,
+ num_tokens=8,
+ num_vectors=768,
+ reducer_channels=512,
+ channels=2048,
+ downscale_factor: int = 8,
+ **kwargs
+ ):
+ self.in_channels = in_channels
+ self.num_tokens = num_tokens
+ self.num_vectors = num_vectors
+ self.reducer_channels = reducer_channels
+ self.channels = channels
+ self.downscale_factor = downscale_factor
+ self.image_size = 224
+
+ self.hidden_size = num_vectors
+ self.projection_dim = num_vectors
+
+
+class SAFEVMReturn:
+ def __init__(self, output):
+ self.output = output
+ # todo actually do hidden states. This is just for code compatibility for now
+ self.hidden_states = [output for _ in range(13)]
+
+
+class SAFEVisionModel(SizeAgnosticFeatureEncoder):
+ def __init__(self, **kwargs):
+ self.config = SAFEVMConfig(**kwargs)
+ self.image_size = None
+ # super().__init__(**kwargs)
+ super(SAFEVisionModel, self).__init__(**kwargs)
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ # not needed
+ return SAFEVisionModel(**kwargs)
+
+ def forward(self, x, **kwargs):
+ return SAFEVMReturn(super().forward(x))
diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc7679aac14f803ef58f6ad3ed078076234c5695
--- /dev/null
+++ b/toolkit/models/te_adapter.py
@@ -0,0 +1,460 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import Union, TYPE_CHECKING
+
+
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
+from diffusers.models.embeddings import PixArtAlphaTextProjection
+
+from toolkit import train_tools
+from toolkit.paths import REPOS_ROOT
+from toolkit.prompt_utils import PromptEmbeds
+from diffusers import Transformer2DModel
+
+sys.path.append(REPOS_ROOT)
+
+from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
+
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
+ from toolkit.custom_adapter import CustomAdapter
+
+
+class TEAdapterCaptionProjection(nn.Module):
+ def __init__(self, caption_channels, adapter: 'TEAdapter'):
+ super().__init__()
+ in_features = caption_channels
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ sd = adapter.sd_ref()
+ self.parent_module_ref = weakref.ref(sd.unet.caption_projection)
+ parent_module = self.parent_module_ref()
+ self.linear_1 = nn.Linear(
+ in_features=in_features,
+ out_features=parent_module.linear_1.out_features,
+ bias=True
+ )
+ self.linear_2 = nn.Linear(
+ in_features=parent_module.linear_2.in_features,
+ out_features=parent_module.linear_2.out_features,
+ bias=True
+ )
+
+ # save the orig forward
+ parent_module.linear_1.orig_forward = parent_module.linear_1.forward
+ parent_module.linear_2.orig_forward = parent_module.linear_2.forward
+
+ # replace original forward
+ parent_module.orig_forward = parent_module.forward
+ parent_module.forward = self.forward
+
+
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+ @property
+ def unconditional_embeds(self):
+ return self.adapter_ref().adapter_ref().unconditional_embeds
+
+ @property
+ def conditional_embeds(self):
+ return self.adapter_ref().adapter_ref().conditional_embeds
+
+ def forward(self, caption):
+ if self.is_active and self.conditional_embeds is not None:
+ adapter_hidden_states = self.conditional_embeds.text_embeds
+ # check if we are doing unconditional
+ if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]:
+ # concat unconditional to match the hidden state batch size
+ if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
+ unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
+ else:
+ unconditional = self.unconditional_embeds.text_embeds
+ adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
+ hidden_states = self.linear_1(adapter_hidden_states)
+ hidden_states = self.parent_module_ref().act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+ else:
+ return self.parent_module_ref().orig_forward(caption)
+
+
+class TEAdapterAttnProcessor(nn.Module):
+ r"""
+ Attention processor for a custom text encoder (TE) adapter, using PyTorch 2.0 scaled dot-product attention.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ the weight scale of the adapter conditioning.
+ num_tokens (`int`, defaults to 4; 16 for ip-adapter-plus style adapters):
+ The context length of the adapter features.
+ adapter (`TEAdapter`):
+ The parent adapter, held by weak reference.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None,
+ adapter_hidden_size=None, layer_name=None):
+ super().__init__()
+ self.layer_name = layer_name
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+ self.hidden_size = hidden_size
+ self.adapter_hidden_size = adapter_hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False)
+ self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False)
+
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+ @property
+ def unconditional_embeds(self):
+ return self.adapter_ref().adapter_ref().unconditional_embeds
+
+ @property
+ def conditional_embeds(self):
+ return self.adapter_ref().adapter_ref().conditional_embeds
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ is_active = self.adapter_ref().is_active
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ # will be none if disabled
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ # only use one TE or the other. If our adapter is active only use ours
+ if self.is_active and self.conditional_embeds is not None:
+ adapter_hidden_states = self.conditional_embeds.text_embeds
+ # check if we are doing unconditional
+ if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]:
+ # concat unconditional to match the hidden state batch size
+ if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
+ unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
+ else:
+ unconditional = self.unconditional_embeds.text_embeds
+ adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
+ # for ip-adapter
+ key = self.to_k_adapter(adapter_hidden_states)
+ value = self.to_v_adapter(adapter_hidden_states)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ try:
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ except RuntimeError:
+ raise RuntimeError(f"key shape: {key.shape}, value shape: {value.shape}")
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ # remove attn mask if doing clip
+ if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
+ attention_mask = None
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class TEAdapter(torch.nn.Module):
+ def __init__(
+ self,
+ adapter: 'CustomAdapter',
+ sd: 'StableDiffusion',
+ te: Union[T5EncoderModel, CLIPTextModel],
+ tokenizer: CLIPTokenizer
+ ):
+ super(TEAdapter, self).__init__()
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ self.te_ref: weakref.ref = weakref.ref(te)
+ self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
+ self.adapter_modules = []
+ self.caption_projection = None
+ self.embeds_store = []
+ is_pixart = sd.is_pixart
+
+ if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
+ self.token_size = self.te_ref().config.d_model
+ else:
+ self.token_size = self.te_ref().config.hidden_size
+
+ # add text projection if is sdxl
+ self.text_projection = None
+ if sd.is_xl:
+ clip_with_projection: CLIPTextModelWithProjection = sd.text_encoder[0]
+ self.text_projection = nn.Linear(te.config.hidden_size, clip_with_projection.config.projection_dim, bias=False)
+
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = sd.unet.state_dict()
+ attn_processor_keys = []
+ if is_pixart:
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+ # cross attention
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+
+ else:
+ attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+ attn_processor_names = []
+
+ blocks = []
+ transformer_blocks = []
+ for name in attn_processor_keys:
+ cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
+ sd.unet.config['cross_attention_dim']
+ if name.startswith("mid_block"):
+ hidden_size = sd.unet.config['block_out_channels'][-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = sd.unet.config['block_out_channels'][block_id]
+ elif name.startswith("transformer"):
+ hidden_size = sd.unet.config['cross_attention_dim']
+ else:
+ # not in the upstream implementation, but hidden_size would be undefined below without it
+ raise ValueError(f"unknown attn processor name: {name}")
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor2_0()
+ else:
+ layer_name = name.split(".processor")[0]
+ to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+ to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
+
+ # add zero padding to the adapter
+ if to_k_adapter.shape[1] < self.token_size:
+ to_k_adapter = torch.cat([
+ to_k_adapter,
+ torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to(
+ to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
+ ],
+ dim=1
+ )
+ to_v_adapter = torch.cat([
+ to_v_adapter,
+ torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to(
+ to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
+ ],
+ dim=1
+ )
+ elif to_k_adapter.shape[1] > self.token_size:
+ to_k_adapter = to_k_adapter[:, :self.token_size]
+ to_v_adapter = to_v_adapter[:, :self.token_size]
+ else:
+ to_k_adapter = to_k_adapter
+ to_v_adapter = to_v_adapter
+
+ # todo resize to the TE hidden size
+ weights = {
+ "to_k_adapter.weight": to_k_adapter,
+ "to_v_adapter.weight": to_v_adapter,
+ }
+
+ if self.sd_ref().is_pixart:
+ # pixart is much more sensitive
+ weights = {
+ "to_k_adapter.weight": weights["to_k_adapter.weight"] * 0.01,
+ "to_v_adapter.weight": weights["to_v_adapter.weight"] * 0.01,
+ }
+
+ attn_procs[name] = TEAdapterAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=self.adapter_ref().config.num_tokens,
+ adapter=self,
+ adapter_hidden_size=self.token_size,
+ layer_name=layer_name
+ )
+ attn_procs[name].load_state_dict(weights)
+ self.adapter_modules.append(attn_procs[name])
+ if self.sd_ref().is_pixart:
+ # we have to set them ourselves
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+ module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+ self.adapter_modules = torch.nn.ModuleList(
+ [
+ transformer.transformer_blocks[i].attn2.processor for i in
+ range(len(transformer.transformer_blocks))
+ ])
+ self.caption_projection = TEAdapterCaptionProjection(
+ caption_channels=self.token_size,
+ adapter=self,
+ )
+
+ else:
+ sd.unet.set_attn_processor(attn_procs)
+ self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+
+ # make a getter to see if is active
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+ def encode_text(self, text):
+ te: T5EncoderModel = self.te_ref()
+ tokenizer: T5Tokenizer = self.tokenizer_ref()
+ attn_mask_float = None
+
+ # input_ids = tokenizer(
+ # text,
+ # max_length=77,
+ # padding="max_length",
+ # truncation=True,
+ # return_tensors="pt",
+ # ).input_ids.to(te.device)
+ # outputs = te(input_ids=input_ids)
+ # outputs = outputs.last_hidden_state
+ if self.adapter_ref().config.text_encoder_arch == "clip":
+ embeds = train_tools.encode_prompts(
+ tokenizer,
+ te,
+ text,
+ truncate=True,
+ max_length=self.adapter_ref().config.num_tokens,
+ )
+ attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)
+
+ elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
+ # just use aura pile
+ embeds, attention_mask = train_tools.encode_prompts_auraflow(
+ tokenizer,
+ te,
+ text,
+ truncate=True,
+ max_length=self.adapter_ref().config.num_tokens,
+ )
+
+ else:
+ embeds, attention_mask = train_tools.encode_prompts_pixart(
+ tokenizer,
+ te,
+ text,
+ truncate=True,
+ max_length=self.adapter_ref().config.num_tokens,
+ )
+ if attention_mask is not None:
+ attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
+ if self.text_projection is not None:
+ # pool the output of embeds ignoring 0 in the attention mask
+ if attn_mask_float is not None:
+ pooled_output = embeds * attn_mask_float.unsqueeze(-1)
+ else:
+ pooled_output = embeds
+
+ # reduce along dim 1 while maintaining batch and dim 2
+ pooled_output_sum = pooled_output.sum(dim=1)
+
+ if attn_mask_float is not None:
+ attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1)
+
+ pooled_output = pooled_output_sum / attn_mask_sum
+
+ pooled_embeds = self.text_projection(pooled_output)
+
+ prompt_embeds = PromptEmbeds(
+ (embeds, pooled_embeds),
+ attention_mask=attention_mask,
+ ).detach()
+
+ else:
+
+ prompt_embeds = PromptEmbeds(
+ embeds,
+ attention_mask=attention_mask,
+ ).detach()
+
+ return prompt_embeds
+
+
+
+ def forward(self, input):
+ return input
diff --git a/toolkit/models/te_aug_adapter.py b/toolkit/models/te_aug_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..02cbbec1a6eb4fdcce7d94067a976fbde496f89c
--- /dev/null
+++ b/toolkit/models/te_aug_adapter.py
@@ -0,0 +1,253 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import Union, TYPE_CHECKING, Optional, Tuple
+
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
+from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
+
+from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
+from toolkit.paths import REPOS_ROOT
+from toolkit.resampler import Resampler
+
+sys.path.append(REPOS_ROOT)
+
+from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+ from toolkit.custom_adapter import CustomAdapter
+
+
+class TEAugAdapterCLIPAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
+ super().__init__()
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
+ self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
+ self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
+ # copy the weights from the original module
+ self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
+ self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
+ # copy the bias, scaled down to near zero
+ self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
+ self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001
+
+ self.zipper = ZipperModule(
+ in_size=attn_module.embed_dim,
+ in_tokens=77 * 2,
+ out_size=attn_module.embed_dim,
+ out_tokens=77,
+ hidden_size=attn_module.embed_dim,
+ hidden_tokens=77,
+ )
+ # self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data)
+ # self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data)
+ # #reset the bias
+ # self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data)
+ # self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data)
+
+ # replace the original forward with our forward
+ self.original_forward = attn_module.forward
+ attn_module.forward = self.forward
+
+
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ attn_module = self.attn_module_ref()
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = attn_module.q_proj(hidden_states) * attn_module.scale
+ key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
+ value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
+ query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights has to be reshaped
+ # twice and has to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
+ if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
+ # apply the adapter
+
+ if adapter.is_unconditional_run:
+ embeds = adapter.unconditional_embeds
+ else:
+ embeds = adapter.conditional_embeds
+ # if the shape is not the same on batch, we are doing cfg and need to concat unconditional as well
+ if embeds.size(0) != bsz:
+ embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)
+
+ key_states_raw = self.k_proj_adapter(embeds)
+ key_states = attn_module._shape(key_states_raw, -1, bsz)
+ value_states_raw = self.v_proj_adapter(embeds)
+ value_states = attn_module._shape(value_states_raw, -1, bsz)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
+ attn_output_adapter = torch.bmm(attn_probs, value_states)
+
+ if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
+ raise ValueError(
+ f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
+ f" {attn_output_adapter.size()}"
+ )
+
+ attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
+ attn_output_adapter = attn_output_adapter.transpose(1, 2)
+ attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))
+
+ # attn_output_adapter = attn_module.out_proj(attn_output_adapter)
+ attn_output = attn_output + attn_output_adapter
+
+ attn_output = attn_module.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+class TEAugAdapter(torch.nn.Module):
+ def __init__(
+ self,
+ adapter: 'CustomAdapter',
+ sd: 'StableDiffusion',
+ ):
+ super(TEAugAdapter, self).__init__()
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+
+ if isinstance(sd.text_encoder, list):
+ raise ValueError("Dual text encoders is not yet supported")
+
+ # dim will come from text encoder
+ # dim = sd.unet.config['cross_attention_dim']
+ text_encoder: CLIPTextModel = sd.text_encoder
+ dim = text_encoder.config.hidden_size
+
+ clip_encoder: CLIPEncoder = text_encoder.text_model.encoder
+ # dim = clip_encoder.layers[-1].self_attn
+
+ if hasattr(adapter.vision_encoder.config, 'hidden_sizes'):
+ embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
+ else:
+ embedding_dim = adapter.vision_encoder.config.hidden_size
+
+ image_encoder_state_dict = adapter.vision_encoder.state_dict()
+ # max_seq_len = CLIP tokens + CLS token
+ in_tokens = 257
+ if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+ # clip
+ in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+ if adapter.config.image_encoder_arch.startswith('convnext'):
+ in_tokens = 16 * 16
+ embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
+
+ out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens
+ self.image_proj_model = ZipperModule(
+ in_size=embedding_dim,
+ in_tokens=in_tokens,
+ out_size=dim,
+ out_tokens=out_tokens,
+ hidden_size=dim,
+ hidden_tokens=out_tokens,
+ )
+ # init adapter modules
+ attn_procs = {}
+ for idx, layer in enumerate(clip_encoder.layers):
+ name = f"clip_attention.{idx}"
+ attn_procs[name] = TEAugAdapterCLIPAttention(
+ layer.self_attn,
+ self
+ )
+
+ self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values()))
+
+ # make a getter to see if is active
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+
+ def forward(self, input):
+ # # apply the adapter
+ input = self.image_proj_model(input)
+ # self.embeds = input
+ return input
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea3f9bc757143570067c2387ec8a6a96c909690d
--- /dev/null
+++ b/toolkit/models/vd_adapter.py
@@ -0,0 +1,812 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import Union, TYPE_CHECKING, Optional
+from collections import OrderedDict
+
+from diffusers import Transformer2DModel, FluxTransformer2DModel
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
+from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
+from transformers import SiglipImageProcessor, SiglipVisionModel
+
+from toolkit.config_modules import AdapterConfig
+from toolkit.paths import REPOS_ROOT
+sys.path.append(REPOS_ROOT)
+
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+ from toolkit.custom_adapter import CustomAdapter
+
+
+# matches distribution of randn
+class Norm(nn.Module):
+ def __init__(self, target_mean=0.0, target_std=1.0, eps=1e-6):
+ super(Norm, self).__init__()
+ self.target_mean = target_mean
+ self.target_std = target_std
+ self.eps = eps
+
+ def forward(self, x):
+ dims = tuple(range(1, x.dim()))
+ mean = x.mean(dim=dims, keepdim=True)
+ std = x.std(dim=dims, keepdim=True)
+
+ # Normalize
+ return self.target_std * (x - mean) / (std + self.eps) + self.target_mean
+
+
+norm_layer = Norm()
+
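+# Illustrative usage sketch (the tensor shapes below are placeholder assumptions):
+# Norm rescales each sample so its per-sample mean/std match target_mean/target_std,
+# keeping adapter embeddings in roughly the same range as torch.randn noise.
+#
+#   x = torch.randn(2, 77, 768) * 5.0 + 3.0      # arbitrary scale and offset
+#   y = norm_layer(x)
+#   # y.mean(dim=(1, 2)) is ~0.0 and y.std(dim=(1, 2)) is ~1.0 for each sample
+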
+class SparseAutoencoder(nn.Module):
+ def __init__(self, input_dim, hidden_dim, output_dim):
+ super(SparseAutoencoder, self).__init__()
+ self.encoder = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, output_dim),
+ )
+ self.norm = Norm()
+ self.decoder = nn.Sequential(
+ nn.Linear(output_dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, input_dim),
+ )
+ self.last_run = None
+
+ def forward(self, x):
+ self.last_run = {
+ "input": x
+ }
+ x = self.encoder(x)
+ x = self.norm(x)
+ self.last_run["sparse"] = x
+ x = self.decoder(x)
+ x = self.norm(x)
+ self.last_run["output"] = x
+ return x
+
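+# Illustrative usage sketch (dims are placeholder assumptions): the encoder expands
+# features into a wider "sparse" space, the decoder reconstructs the input size, and
+# last_run keeps the intermediates for inspection or auxiliary losses.
+#
+#   sae = SparseAutoencoder(input_dim=768, hidden_dim=1536, output_dim=4096)
+#   recon = sae(torch.randn(1, 16, 768))         # -> (1, 16, 768)
+#   codes = sae.last_run["sparse"]               # -> (1, 16, 4096)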
+
+class MLPR(nn.Module): # MLP with reshaping
+ def __init__(
+ self,
+ in_dim,
+ in_channels,
+ out_dim,
+ out_channels,
+ use_residual=True
+ ):
+ super().__init__()
+ if use_residual:
+ assert in_dim == out_dim
+ # dont normalize if using conv
+ self.layer_norm = nn.LayerNorm(in_dim)
+
+ self.fc1 = nn.Linear(in_dim, out_dim)
+ self.act_fn = nn.GELU()
+ self.conv1 = nn.Conv1d(in_channels, out_channels, 1)
+
+ def forward(self, x):
+ residual = x
+ x = self.layer_norm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.conv1(x)
+ return x
+
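+# Illustrative shape sketch (dims are placeholder assumptions): MLPR projects the
+# feature axis with a Linear and the token axis with a 1x1 Conv1d, e.g. reducing a
+# set of vision tokens to a fixed number of adapter tokens.
+#
+#   mlpr = MLPR(in_dim=1024, in_channels=257, out_dim=1024, out_channels=16)
+#   out = mlpr(torch.randn(2, 257, 1024))        # -> (2, 16, 1024)
+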
+class AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+class VisionDirectAdapterAttnProcessor(nn.Module):
+ r"""
+ Attention processor (PyTorch 2.0 SDPA) that adds a vision-direct adapter branch on top of the regular attention.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ adapter (`VisionDirectAdapter`):
+ The parent adapter; only a weak reference to it is kept.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
+ adapter_hidden_size=None, has_bias=False, **kwargs):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+ self.hidden_size = hidden_size
+ self.adapter_hidden_size = adapter_hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+
+ self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+ self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+ # return False
+
+ @property
+ def unconditional_embeds(self):
+ return self.adapter_ref().adapter_ref().unconditional_embeds
+
+ @property
+ def conditional_embeds(self):
+ return self.adapter_ref().adapter_ref().conditional_embeds
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ is_active = self.adapter_ref().is_active
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ # will be none if disabled
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # only use one TE or the other. If our adapter is active only use ours
+ if self.is_active and self.conditional_embeds is not None:
+
+ adapter_hidden_states = self.conditional_embeds
+ if adapter_hidden_states.shape[0] < batch_size:
+ adapter_hidden_states = torch.cat([
+ self.unconditional_embeds,
+ adapter_hidden_states
+ ], dim=0)
+ # if these are image embeds, add a singleton dim at index 1
+ if len(adapter_hidden_states.shape) == 2:
+ adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+ # conditional_batch_size = adapter_hidden_states.shape[0]
+ # conditional_query = query
+
+ # for ip-adapter
+ vd_key = self.to_k_adapter(adapter_hidden_states)
+ vd_value = self.to_v_adapter(adapter_hidden_states)
+
+ vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ vd_hidden_states = F.scaled_dot_product_attention(
+ query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ vd_hidden_states = vd_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + self.scale * vd_hidden_states
+
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
+ adapter_hidden_size=None, has_bias=False, block_idx=0, **kwargs):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+ self.hidden_size = hidden_size
+ self.adapter_hidden_size = adapter_hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.block_idx = block_idx
+
+ self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+ self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+ # return False
+
+ @property
+ def unconditional_embeds(self):
+ return self.adapter_ref().adapter_ref().unconditional_embeds
+
+ @property
+ def conditional_embeds(self):
+ return self.adapter_ref().adapter_ref().conditional_embeds
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # begin ip adapter
+ if self.is_active and self.conditional_embeds is not None:
+ adapter_hidden_states = self.conditional_embeds
+ block_scaler = self.adapter_ref().block_scaler
+ if block_scaler is not None:
+ # add 1 to block scaler so we can decay its weight to 1.0
+ block_scaler = block_scaler[self.block_idx] + 1.0
+
+ if adapter_hidden_states.shape[0] < batch_size:
+ adapter_hidden_states = torch.cat([
+ self.unconditional_embeds,
+ adapter_hidden_states
+ ], dim=0)
+ # if these are image embeds, add a singleton dim at index 1
+ if len(adapter_hidden_states.shape) == 2:
+ adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+ # conditional_batch_size = adapter_hidden_states.shape[0]
+ # conditional_query = query
+
+ # for ip-adapter
+ vd_key = self.to_k_adapter(adapter_hidden_states)
+ vd_value = self.to_v_adapter(adapter_hidden_states)
+
+ vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ vd_hidden_states = F.scaled_dot_product_attention(
+ query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ vd_hidden_states = vd_hidden_states.to(query.dtype)
+
+ # scale to block scaler
+ if block_scaler is not None:
+ orig_dtype = vd_hidden_states.dtype
+ if block_scaler.dtype != vd_hidden_states.dtype:
+ vd_hidden_states = vd_hidden_states.to(block_scaler.dtype)
+ vd_hidden_states = vd_hidden_states * block_scaler
+ if block_scaler.dtype != orig_dtype:
+ vd_hidden_states = vd_hidden_states.to(orig_dtype)
+
+ hidden_states = hidden_states + self.scale * vd_hidden_states
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+class VisionDirectAdapter(torch.nn.Module):
+ def __init__(
+ self,
+ adapter: 'CustomAdapter',
+ sd: 'StableDiffusion',
+ vision_model: Union[CLIPVisionModelWithProjection],
+ ):
+ super(VisionDirectAdapter, self).__init__()
+ is_pixart = sd.is_pixart
+ is_flux = sd.is_flux
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ self.config: AdapterConfig = adapter.config
+ self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
+ self.resampler = None
+ is_pixtral = self.config.image_encoder_arch == "pixtral"
+
+ if adapter.config.clip_layer == "image_embeds":
+ if isinstance(vision_model, SiglipVisionModel):
+ self.token_size = vision_model.config.hidden_size
+ else:
+ self.token_size = vision_model.config.projection_dim
+ else:
+ self.token_size = vision_model.config.hidden_size
+
+ self.mid_size = self.token_size
+
+ if self.config.conv_pooling and self.config.conv_pooling_stacks > 1:
+ self.mid_size = self.mid_size * self.config.conv_pooling_stacks
+
+ # if pixtral and only training the double transformer blocks, use the cross attention dim for a sparser mid representation
+ if is_pixtral and self.config.flux_only_double:
+ if is_flux:
+ hidden_size = 3072
+ else:
+ hidden_size = sd.unet.config['cross_attention_dim']
+ self.mid_size = hidden_size
+
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = sd.unet.state_dict()
+
+ attn_processor_keys = []
+ if is_pixart:
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+ # cross attention
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+
+ elif is_flux:
+ transformer: FluxTransformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ attn_processor_keys.append(f"transformer_blocks.{i}.attn")
+
+ if not self.config.flux_only_double:
+ # single transformer blocks do not have cross attn, but we will do them anyway
+ for i, module in transformer.single_transformer_blocks.named_children():
+ attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
+ else:
+ attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+ current_idx = 0
+
+ for name in attn_processor_keys:
+ if is_flux:
+ cross_attention_dim = None
+ else:
+ cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
+ if name.startswith("mid_block"):
+ hidden_size = sd.unet.config['block_out_channels'][-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = sd.unet.config['block_out_channels'][block_id]
+ elif name.startswith("transformer") or name.startswith("single_transformer"):
+ if is_flux:
+ hidden_size = 3072
+ else:
+ hidden_size = sd.unet.config['cross_attention_dim']
+ else:
+ # the original code did not handle this case, which would otherwise leave hidden_size undefined
+ raise ValueError(f"unknown attn processor name: {name}")
+ if cross_attention_dim is None and not is_flux:
+ attn_procs[name] = AttnProcessor2_0()
+ else:
+ layer_name = name.split(".processor")[0]
+ if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:
+ # is quantized
+
+ to_k_adapter = torch.randn(hidden_size, hidden_size) * 0.01
+ to_v_adapter = torch.randn(hidden_size, hidden_size) * 0.01
+ to_k_adapter = to_k_adapter.to(self.sd_ref().torch_dtype)
+ to_v_adapter = to_v_adapter.to(self.sd_ref().torch_dtype)
+ else:
+ to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+ to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
+
+ # add zero padding to the adapter
+ if to_k_adapter.shape[1] < self.mid_size:
+ to_k_adapter = torch.cat([
+ to_k_adapter,
+ torch.randn(to_k_adapter.shape[0], self.mid_size - to_k_adapter.shape[1]).to(
+ to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
+ ],
+ dim=1
+ )
+ to_v_adapter = torch.cat([
+ to_v_adapter,
+ torch.randn(to_v_adapter.shape[0], self.mid_size - to_v_adapter.shape[1]).to(
+ to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
+ ],
+ dim=1
+ )
+ elif to_k_adapter.shape[1] > self.mid_size:
+ to_k_adapter = to_k_adapter[:, :self.mid_size]
+ to_v_adapter = to_v_adapter[:, :self.mid_size]
+ # if is_pixart:
+ # to_k_bias = to_k_bias[:self.mid_size]
+ # to_v_bias = to_v_bias[:self.mid_size]
+ else:
+ to_k_adapter = to_k_adapter
+ to_v_adapter = to_v_adapter
+ # if is_pixart:
+ # to_k_bias = to_k_bias
+ # to_v_bias = to_v_bias
+
+ weights = {
+ "to_k_adapter.weight": to_k_adapter * 0.01,
+ "to_v_adapter.weight": to_v_adapter * 0.01,
+ }
+ # if is_pixart:
+ # weights["to_k_adapter.bias"] = to_k_bias
+ # weights["to_v_adapter.bias"] = to_v_bias\
+
+ if is_flux:
+ attn_procs[name] = CustomFluxVDAttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ adapter=self,
+ adapter_hidden_size=self.mid_size,
+ has_bias=False,
+ block_idx=current_idx
+ )
+ else:
+ attn_procs[name] = VisionDirectAdapterAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ adapter=self,
+ adapter_hidden_size=self.mid_size,
+ has_bias=False,
+ )
+ current_idx += 1
+ attn_procs[name].load_state_dict(weights)
+
+ if self.sd_ref().is_pixart:
+ # we have to set them ourselves
+ transformer: Transformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+ module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+ self.adapter_modules = torch.nn.ModuleList([
+ transformer.transformer_blocks[i].attn1.processor for i in range(len(transformer.transformer_blocks))
+ ] + [
+ transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks))
+ ])
+ elif self.sd_ref().is_flux:
+ # we have to set them ourselves
+ transformer: FluxTransformer2DModel = sd.unet
+ for i, module in transformer.transformer_blocks.named_children():
+ module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
+
+ if not self.config.flux_only_double:
+ # do single blocks too even though they dont have cross attn
+ for i, module in transformer.single_transformer_blocks.named_children():
+ module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+
+ if not self.config.flux_only_double:
+ self.adapter_modules = torch.nn.ModuleList(
+ [
+ transformer.transformer_blocks[i].attn.processor for i in
+ range(len(transformer.transformer_blocks))
+ ] + [
+ transformer.single_transformer_blocks[i].attn.processor for i in
+ range(len(transformer.single_transformer_blocks))
+ ]
+ )
+ else:
+ self.adapter_modules = torch.nn.ModuleList(
+ [
+ transformer.transformer_blocks[i].attn.processor for i in
+ range(len(transformer.transformer_blocks))
+ ]
+ )
+ else:
+ sd.unet.set_attn_processor(attn_procs)
+ self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+
+ num_modules = len(self.adapter_modules)
+ if self.config.train_scaler:
+ self.block_scaler = torch.nn.Parameter(torch.tensor([0.0] * num_modules).to(
+ dtype=torch.float32,
+ device=self.sd_ref().device_torch
+ ))
+ self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+ self.block_scaler.requires_grad = True
+ else:
+ self.block_scaler = None
+
+ self.pool = None
+
+ if self.config.num_tokens is not None:
+ # image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
+ # max_seq_len = CLIP tokens + CLS token
+ # max_seq_len = 257
+ # if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+ # # clip
+ # max_seq_len = int(
+ # image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+ # self.resampler = MLPR(
+ # in_dim=self.token_size,
+ # in_channels=max_seq_len,
+ # out_dim=self.mid_size,
+ # out_channels=self.config.num_tokens,
+ # )
+ vision_config = self.adapter_ref().vision_encoder.config
+ # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
+ # siglip doesnt add 1
+ sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
+ self.pool = nn.Sequential(
+ nn.Conv1d(sequence_length, self.config.num_tokens, 1, bias=False),
+ Norm(),
+ )
+
+ elif self.config.image_encoder_arch == "pixtral":
+ self.resampler = VisionLanguageAdapter(
+ in_dim=self.token_size,
+ out_dim=self.mid_size,
+ )
+
+ self.sparse_autoencoder = None
+ if self.config.conv_pooling:
+ vision_config = self.adapter_ref().vision_encoder.config
+ # sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
+ # siglip doesnt add 1
+ sequence_length = int((vision_config.image_size / vision_config.patch_size) ** 2)
+ self.pool = nn.Sequential(
+ nn.Conv1d(sequence_length, self.config.conv_pooling_stacks, 1, bias=False),
+ Norm(),
+ )
+ if self.config.sparse_autoencoder_dim is not None:
+ hidden_dim = self.token_size * 2
+ if hidden_dim > self.config.sparse_autoencoder_dim:
+ hidden_dim = self.config.sparse_autoencoder_dim
+ self.sparse_autoencoder = SparseAutoencoder(
+ input_dim=self.token_size,
+ hidden_dim=hidden_dim,
+ output_dim=self.config.sparse_autoencoder_dim
+ )
+
+ if self.config.clip_layer == "image_embeds":
+ self.proj = nn.Linear(self.token_size, self.token_size)
+
+ def state_dict(self, destination=None, prefix='', keep_vars=False):
+ if self.config.train_scaler:
+ # only return the block scaler
+ if destination is None:
+ destination = OrderedDict()
+ destination[prefix + 'block_scaler'] = self.block_scaler
+ return destination
+ return super().state_dict(destination, prefix, keep_vars)
+
+ # make a getter to see if is active
+ @property
+ def is_active(self):
+ return self.adapter_ref().is_active
+
+ def forward(self, input):
+ # block scaler keeps moving dtypes. make sure it is float32 here
+ # todo remove this when we have a real solution
+
+ if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
+ self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+ # if doing image_embeds, normalize here
+ if self.config.clip_layer == "image_embeds":
+ input = norm_layer(input)
+ input = self.proj(input)
+ if self.resampler is not None:
+ input = self.resampler(input)
+ if self.pool is not None:
+ input = self.pool(input)
+ if self.config.conv_pooling_stacks > 1:
+ input = torch.cat(torch.chunk(input, self.config.conv_pooling_stacks, dim=1), dim=2)
+ if self.sparse_autoencoder is not None:
+ input = self.sparse_autoencoder(input)
+ return input
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ if self.block_scaler is not None:
+ if self.block_scaler.dtype != torch.float32:
+ self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+ return self
+
+ def post_weight_update(self):
+ # force block scaler to be mean of 1
+ pass
diff --git a/toolkit/models/zipper_resampler.py b/toolkit/models/zipper_resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..35f018b09bd49e802a9a26c225890412706bb1c8
--- /dev/null
+++ b/toolkit/models/zipper_resampler.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+
+
+class ContextualAlphaMask(nn.Module):
+ def __init__(
+ self,
+ dim: int = 768,
+ ):
+ super(ContextualAlphaMask, self).__init__()
+ self.dim = dim
+
+ half_dim = dim // 2
+ quarter_dim = dim // 4
+
+ self.fc1 = nn.Linear(self.dim, self.dim)
+ self.fc2 = nn.Linear(self.dim, half_dim)
+ self.norm1 = nn.LayerNorm(half_dim)
+ self.fc3 = nn.Linear(half_dim, half_dim)
+ self.fc4 = nn.Linear(half_dim, quarter_dim)
+ self.norm2 = nn.LayerNorm(quarter_dim)
+ self.fc5 = nn.Linear(quarter_dim, quarter_dim)
+ self.fc6 = nn.Linear(quarter_dim, 1)
+ # set fc6 weights to near zero
+ self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ # x = (batch_size, 77, 768)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ x = self.norm1(x)
+ x = self.act_fn(x)
+ x = self.fc3(x)
+ x = self.act_fn(x)
+ x = self.fc4(x)
+ x = self.norm2(x)
+ x = self.act_fn(x)
+ x = self.fc5(x)
+ x = self.act_fn(x)
+ x = self.fc6(x)
+ x = torch.sigmoid(x)
+ return x
+
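+# Illustrative usage sketch (dims are placeholder assumptions): produces a per-token
+# gate in (0, 1); since fc6 starts near zero, the gate begins close to 0.5.
+#
+#   gate = ContextualAlphaMask(dim=768)(torch.randn(2, 77, 768))   # -> (2, 77, 1)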
+
+class ZipperModule(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ in_tokens,
+ out_size,
+ out_tokens,
+ hidden_size,
+ hidden_tokens,
+ use_residual=False,
+ ):
+ super().__init__()
+ self.in_size = in_size
+ self.in_tokens = in_tokens
+ self.out_size = out_size
+ self.out_tokens = out_tokens
+ self.hidden_size = hidden_size
+ self.hidden_tokens = hidden_tokens
+ self.use_residual = use_residual
+
+ self.act_fn = nn.GELU()
+ self.layernorm = nn.LayerNorm(self.in_size)
+
+ self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
+ # act
+ self.fc1 = nn.Linear(self.in_size, self.hidden_size)
+ # act
+ self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
+ # act
+ self.fc2 = nn.Linear(self.hidden_size, self.out_size)
+
+ def forward(self, x):
+ residual = x
+ x = self.layernorm(x)
+ x = self.conv1(x)
+ x = self.act_fn(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.conv2(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ if self.use_residual:
+ x = x + residual
+ return x
+
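+# Illustrative shape sketch (dims are placeholder assumptions): ZipperModule alternates
+# Conv1d over the token axis with Linear over the feature axis, so both the token count
+# and the feature size can change in one block.
+#
+#   zipper = ZipperModule(in_size=1152, in_tokens=729, out_size=768, out_tokens=77,
+#                         hidden_size=768, hidden_tokens=77)
+#   out = zipper(torch.randn(2, 729, 1152))      # -> (2, 77, 768)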
+
+class ZipperResampler(nn.Module):
+ def __init__(
+ self,
+ in_size,
+ in_tokens,
+ out_size,
+ out_tokens,
+ hidden_size,
+ hidden_tokens,
+ num_blocks=1,
+ is_conv_input=False,
+ ):
+ super().__init__()
+ self.is_conv_input = is_conv_input
+
+ module_list = []
+ for i in range(num_blocks):
+
+ this_in_size = in_size
+ this_in_tokens = in_tokens
+ this_out_size = out_size
+ this_out_tokens = out_tokens
+ this_hidden_size = hidden_size
+ this_hidden_tokens = hidden_tokens
+ use_residual = False
+
+ # maintain middle sizes as hidden_size
+ if i == 0: # first block
+ this_in_size = in_size
+ this_in_tokens = in_tokens
+ if num_blocks == 1:
+ this_out_size = out_size
+ this_out_tokens = out_tokens
+ else:
+ this_out_size = hidden_size
+ this_out_tokens = hidden_tokens
+ elif i == num_blocks - 1: # last block
+ this_out_size = out_size
+ this_out_tokens = out_tokens
+ if num_blocks == 1:
+ this_in_size = in_size
+ this_in_tokens = in_tokens
+ else:
+ this_in_size = hidden_size
+ this_in_tokens = hidden_tokens
+ else: # middle blocks
+ this_out_size = hidden_size
+ this_out_tokens = hidden_tokens
+ this_in_size = hidden_size
+ this_in_tokens = hidden_tokens
+ use_residual = True
+
+ module_list.append(ZipperModule(
+ in_size=this_in_size,
+ in_tokens=this_in_tokens,
+ out_size=this_out_size,
+ out_tokens=this_out_tokens,
+ hidden_size=this_hidden_size,
+ hidden_tokens=this_hidden_tokens,
+ use_residual=use_residual
+ ))
+
+ self.blocks = nn.ModuleList(module_list)
+
+ self.ctx_alpha = ContextualAlphaMask(
+ dim=out_size,
+ )
+
+ def forward(self, x):
+ if self.is_conv_input:
+ # flatten
+ x = x.view(x.size(0), x.size(1), -1)
+ # rearrange to (batch, tokens, size)
+ x = x.permute(0, 2, 1)
+
+ for block in self.blocks:
+ x = block(x)
+ alpha = self.ctx_alpha(x)
+ return x * alpha
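+
+# Illustrative usage sketch (dims are placeholder assumptions): stacked blocks keep the
+# middle at hidden_size/hidden_tokens and the contextual alpha gate scales each output
+# token between 0 and 1.
+#
+#   resampler = ZipperResampler(in_size=1280, in_tokens=257, out_size=768, out_tokens=16,
+#                               hidden_size=768, hidden_tokens=64, num_blocks=2)
+#   image_tokens = resampler(torch.randn(2, 257, 1280))   # -> (2, 16, 768)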
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..37f7987e2868e545627206de8e2e2654884c12a4
--- /dev/null
+++ b/toolkit/network_mixins.py
@@ -0,0 +1,727 @@
+import json
+import os
+from collections import OrderedDict
+from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal
+
+import torch
+from optimum.quanto import QTensor
+from torch import nn
+import weakref
+
+from tqdm import tqdm
+
+from toolkit.config_modules import NetworkConfig
+from toolkit.lorm import extract_conv, extract_linear, count_parameters
+from toolkit.metadata import add_model_hash_to_meta
+from toolkit.paths import KEYMAPS_ROOT
+from toolkit.saving import get_lora_keymap_from_model_keymap
+from optimum.quanto import QBytesTensor
+
+if TYPE_CHECKING:
+ from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
+ from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
+ from toolkit.stable_diffusion_model import StableDiffusion
+ from toolkit.models.DoRA import DoRAModule
+
+Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
+Module = Union['LoConSpecialModule', 'LoRAModule', 'DoRAModule']
+
+LINEAR_MODULES = [
+ 'Linear',
+ 'LoRACompatibleLinear',
+ 'QLinear'
+ # 'GroupNorm',
+]
+CONV_MODULES = [
+ 'Conv2d',
+ 'LoRACompatibleConv'
+]
+
+ExtractMode = Union[
+ 'existing',
+ 'fixed',
+ 'threshold',
+ 'ratio',
+ 'quantile',
+ 'percentage'
+]
+
+
+def broadcast_and_multiply(tensor, multiplier):
+ # Determine the number of dimensions required
+ num_extra_dims = tensor.dim() - multiplier.dim()
+
+ # Unsqueezing the tensor to match the dimensionality
+ for _ in range(num_extra_dims):
+ multiplier = multiplier.unsqueeze(-1)
+
+ try:
+ # Multiplying the broadcasted tensor with the output tensor
+ result = tensor * multiplier
+ except RuntimeError as e:
+ print(e)
+ print(tensor.size())
+ print(multiplier.size())
+ raise e
+
+ return result
+
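+# Illustrative sketch (shapes are placeholder assumptions): a per-sample multiplier is
+# unsqueezed until it broadcasts over the trailing dims of the LoRA output.
+#
+#   out = broadcast_and_multiply(torch.randn(4, 320, 64, 64),
+#                                torch.tensor([1.0, 0.5, 0.0, 1.0]))   # -> (4, 320, 64, 64)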
+
+def add_bias(tensor, bias):
+ if bias is None:
+ return tensor
+ # add batch dim
+ bias = bias.unsqueeze(0)
+ bias = torch.cat([bias] * tensor.size(0), dim=0)
+ # Determine the number of dimensions required
+ num_extra_dims = tensor.dim() - bias.dim()
+
+ # Unsqueezing the tensor to match the dimensionality
+ for _ in range(num_extra_dims):
+ bias = bias.unsqueeze(-1)
+
+ # we may need to swap -1 for -2
+ if bias.size(1) != tensor.size(1):
+ if len(bias.size()) == 3:
+ bias = bias.permute(0, 2, 1)
+ elif len(bias.size()) == 4:
+ bias = bias.permute(0, 3, 1, 2)
+
+ # Multiplying the broadcasted tensor with the output tensor
+ try:
+ result = tensor + bias
+ except RuntimeError as e:
+ print(e)
+ print(tensor.size())
+ print(bias.size())
+ raise e
+
+ return result
+
+
+class ExtractableModuleMixin:
+ def extract_weight(
+ self: Module,
+ extract_mode: ExtractMode = "existing",
+ extract_mode_param: Union[int, float] = None,
+ ):
+ device = self.lora_down.weight.device
+ weight_to_extract = self.org_module[0].weight
+ if extract_mode == "existing":
+ extract_mode = 'fixed'
+ extract_mode_param = self.lora_dim
+
+ if isinstance(weight_to_extract, QBytesTensor):
+ weight_to_extract = weight_to_extract.dequantize()
+
+ weight_to_extract = weight_to_extract.clone().detach().float()
+
+ if self.org_module[0].__class__.__name__ in CONV_MODULES:
+ # do conv extraction
+ down_weight, up_weight, new_dim, diff = extract_conv(
+ weight=weight_to_extract,
+ mode=extract_mode,
+ mode_param=extract_mode_param,
+ device=device
+ )
+
+ elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
+ # do linear extraction
+ down_weight, up_weight, new_dim, diff = extract_linear(
+ weight=weight_to_extract,
+ mode=extract_mode,
+ mode_param=extract_mode_param,
+ device=device,
+ )
+ else:
+ raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")
+
+ self.lora_dim = new_dim
+
+ # inject weights into the param
+ self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
+ self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()
+
+ # copy bias if we have one and are using them
+ if self.org_module[0].bias is not None and self.lora_up.bias is not None:
+ self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()
+
+ # set up alphas
+ self.alpha = (self.alpha * 0) + down_weight.shape[0]
+ self.scale = self.alpha / self.lora_dim
+
+ # assign them
+
+ # handle trainable scaler method locon does
+ if hasattr(self, 'scalar'):
+ # scaler is a parameter update the value with 1.0
+ self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)
+
+
+class ToolkitModuleMixin:
+ def __init__(
+ self: Module,
+ *args,
+ network: Network,
+ **kwargs
+ ):
+ self.network_ref: weakref.ref = weakref.ref(network)
+ self.is_checkpointing = False
+ self._multiplier: Union[float, list, torch.Tensor] = None
+
+ def _call_forward(self: Module, x):
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return 0.0 # added to original forward
+
+ if hasattr(self, 'lora_mid') and self.lora_mid is not None:
+ lx = self.lora_mid(self.lora_down(x))
+ else:
+ try:
+ lx = self.lora_down(x)
+ except RuntimeError as e:
+ print(f"Error in {self.__class__.__name__} lora_down")
+ print(e)
+ raise e
+
+ if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
+ lx = self.dropout(lx)
+ # normal dropout
+ elif self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.rank_dropout > 0 and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if len(lx.size()) == 3:
+ mask = mask.unsqueeze(1) # for Text Encoder
+ elif len(lx.size()) == 4:
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ # the scale could also be computed from the mask, but rank_dropout is used directly, expecting an augmentation-like effect
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ # handle trainable scaler method locon does
+ if hasattr(self, 'scalar'):
+ scale = scale * self.scalar
+
+ return lx * scale
+
+ def lorm_forward(self: Module, x, *args, **kwargs):
+ network: Network = self.network_ref()
+ if not network.is_active:
+ return self.org_forward(x, *args, **kwargs)
+
+ orig_dtype = x.dtype
+
+ if x.dtype != self.lora_down.weight.dtype:
+ x = x.to(self.lora_down.weight.dtype)
+
+ if network.lorm_train_mode == 'local':
+ # we are going to predict input with both and do a loss on them
+ inputs = x.detach()
+ with torch.no_grad():
+ # get the local prediction
+ target_pred = self.org_forward(inputs, *args, **kwargs).detach()
+ with torch.set_grad_enabled(True):
+ # make a prediction with the lorm
+ lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))
+
+ local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
+ # backprop the local loss
+ local_loss.backward()
+
+ network.module_losses.append(local_loss.detach())
+ # return the original prediction so this local training step does not affect modules further down the line
+ return target_pred
+
+ else:
+ x = self.lora_up(self.lora_down(x))
+ if x.dtype != orig_dtype:
+ x = x.to(orig_dtype)
+ return x
+
+ def forward(self: Module, x, *args, **kwargs):
+ skip = False
+ network: Network = self.network_ref()
+ if network.is_lorm:
+ # we are doing lorm
+ return self.lorm_forward(x, *args, **kwargs)
+
+ # skip if not active
+ if not network.is_active:
+ skip = True
+
+ # skip if is merged in
+ if network.is_merged_in:
+ skip = True
+
+ # skip if multiplier is 0
+ if network._multiplier == 0:
+ skip = True
+
+ if skip:
+ # network is not active, avoid doing anything
+ return self.org_forward(x, *args, **kwargs)
+
+ # if self.__class__.__name__ == "DoRAModule":
+ # # return dora forward
+ # return self.dora_forward(x, *args, **kwargs)
+
+ org_forwarded = self.org_forward(x, *args, **kwargs)
+
+ if isinstance(x, QTensor):
+ x = x.dequantize()
+ # cast the input to the lora weight dtype
+ lora_input = x.to(self.lora_down.weight.dtype)
+ lora_output = self._call_forward(lora_input)
+ multiplier = self.network_ref().torch_multiplier
+
+ lora_output_batch_size = lora_output.size(0)
+ multiplier_batch_size = multiplier.size(0)
+ if lora_output_batch_size != multiplier_batch_size:
+ num_interleaves = lora_output_batch_size // multiplier_batch_size
+ # todo check if this is correct, do we just concat when doing cfg?
+ multiplier = multiplier.repeat_interleave(num_interleaves)
+
+ scaled_lora_output = broadcast_and_multiply(lora_output, multiplier)
+ scaled_lora_output = scaled_lora_output.to(org_forwarded.dtype)
+
+ if self.__class__.__name__ == "DoRAModule":
+ # ref https://github.com/huggingface/peft/blob/1e6d1d73a0850223b0916052fd8d2382a90eae5a/src/peft/tuners/lora/layer.py#L417
+ # x = dropout(x)
+ # todo: this won't match the dropout applied to the lora
+ if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
+ lx = self.dropout(x)
+ # normal dropout
+ elif self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(x, p=self.dropout)
+ else:
+ lx = x
+ lora_weight = self.lora_up.weight @ self.lora_down.weight
+ # scale it here
+ # todo handle our batch split scalers for slider training. For now take the mean of them
+ scale = multiplier.mean()
+ scaled_lora_weight = lora_weight * scale
+ scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight).to(org_forwarded.dtype)
+
+ try:
+ x = org_forwarded + scaled_lora_output
+ except RuntimeError as e:
+ print(e)
+ print(org_forwarded.size())
+ print(scaled_lora_output.size())
+ raise e
+ return x
+
+ def enable_gradient_checkpointing(self: Module):
+ self.is_checkpointing = True
+
+ def disable_gradient_checkpointing(self: Module):
+ self.is_checkpointing = False
+
+ @torch.no_grad()
+ def merge_out(self: Module, merge_out_weight=1.0):
+ # make sure it is positive
+ merge_out_weight = abs(merge_out_weight)
+ # merging out is just merging in the negative of the weight
+ self.merge_in(merge_weight=-merge_out_weight)
+
+ @torch.no_grad()
+ def merge_in(self: Module, merge_weight=1.0):
+ if not self.can_merge_in:
+ return
+ # get up/down weight
+ up_weight = self.lora_up.weight.clone().float()
+ down_weight = self.lora_down.weight.clone().float()
+
+ # extract weight from org_module
+ org_sd = self.org_module[0].state_dict()
+ # todo find a way to merge in weights when doing quantized model
+ if 'weight._data' in org_sd:
+ # quantized weight
+ return
+
+ weight_key = "weight"
+ if 'weight._data' in org_sd:
+ # quantized weight
+ weight_key = "weight._data"
+
+ orig_dtype = org_sd[weight_key].dtype
+ weight = org_sd[weight_key].float()
+
+ multiplier = merge_weight
+ scale = self.scale
+ # handle trainable scaler method locon does
+ if hasattr(self, 'scalar'):
+ scale = scale * self.scalar
+
+ # merge weight
+ if len(weight.size()) == 2:
+ # linear
+ weight = weight + multiplier * (up_weight @ down_weight) * scale
+ elif down_weight.size()[2:4] == (1, 1):
+ # conv2d 1x1
+ weight = (
+ weight
+ + multiplier
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ * scale
+ )
+ else:
+ # conv2d 3x3
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+ # print(conved.size(), weight.size(), module.stride, module.padding)
+ weight = weight + multiplier * conved * scale
+
+ # set weight to org_module
+ org_sd[weight_key] = weight.to(orig_dtype)
+ self.org_module[0].load_state_dict(org_sd)
+
+ def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
+ # LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping the inputs and
+ # outputs the same. It is essentially a LoRA with the original module removed.
+
+ # if a state dict is passed, use those weights instead of extracting
+ # todo load from state dict
+ network: Network = self.network_ref()
+ lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)
+
+ extract_mode = lorm_config.extract_mode
+ extract_mode_param = lorm_config.extract_mode_param
+ parameter_threshold = lorm_config.parameter_threshold
+ self.extract_weight(
+ extract_mode=extract_mode,
+ extract_mode_param=extract_mode_param
+ )
+
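+# Conceptual sketch of the LoRM idea above (illustrative; sizes and rank are assumptions):
+# a large Linear is approximated by a low-rank pair and the original module is dropped,
+# so the down/up projections become the forward path instead of an additive delta.
+#
+#   big = nn.Linear(4096, 4096, bias=False)                  # ~16.8M params
+#   down = nn.Linear(4096, 64, bias=False)
+#   up = nn.Linear(64, 4096, bias=False)                     # pair: ~0.5M params
+#   approx = up(down(torch.randn(1, 4096)))                  # same in/out shape as big(x)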
+
+class ToolkitNetworkMixin:
+ def __init__(
+ self: Network,
+ *args,
+ train_text_encoder: Optional[bool] = True,
+ train_unet: Optional[bool] = True,
+ is_sdxl=False,
+ is_v2=False,
+ is_ssd=False,
+ is_vega=False,
+ network_config: Optional[NetworkConfig] = None,
+ is_lorm=False,
+ **kwargs
+ ):
+ self.train_text_encoder = train_text_encoder
+ self.train_unet = train_unet
+ self.is_checkpointing = False
+ self._multiplier: float = 1.0
+ self.is_active: bool = False
+ self.is_sdxl = is_sdxl
+ self.is_ssd = is_ssd
+ self.is_vega = is_vega
+ self.is_v2 = is_v2
+ self.is_v1 = not is_v2 and not is_sdxl and not is_ssd and not is_vega
+ self.is_merged_in = False
+ self.is_lorm = is_lorm
+ self.network_config: NetworkConfig = network_config
+ self.module_losses: List[torch.Tensor] = []
+ self.lorm_train_mode: Literal['local', None] = None
+ self.can_merge_in = not is_lorm
+
+ def get_keymap(self: Network, force_weight_mapping=False):
+ use_weight_mapping = False
+
+ if self.is_ssd:
+ keymap_tail = 'ssd'
+ use_weight_mapping = True
+ elif self.is_vega:
+ keymap_tail = 'vega'
+ use_weight_mapping = True
+ elif self.is_sdxl:
+ keymap_tail = 'sdxl'
+ elif self.is_v2:
+ keymap_tail = 'sd2'
+ else:
+ keymap_tail = 'sd1'
+ # todo double check this
+ # use_weight_mapping = True
+
+ if force_weight_mapping:
+ use_weight_mapping = True
+
+ # load keymap
+ keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
+ if use_weight_mapping:
+ keymap_name = f"stable_diffusion_{keymap_tail}.json"
+
+ keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
+
+ keymap = None
+ # check if file exists
+ if os.path.exists(keymap_path):
+ with open(keymap_path, 'r') as f:
+ keymap = json.load(f)['ldm_diffusers_keymap']
+
+ if use_weight_mapping and keymap is not None:
+ # get keymap from weights
+ keymap = get_lora_keymap_from_model_keymap(keymap)
+
+ # upgrade keymaps for DoRA
+ if self.network_type.lower() == 'dora':
+ if keymap is not None:
+ new_keymap = {}
+ for ldm_key, diffusers_key in keymap.items():
+ ldm_key = ldm_key.replace('.alpha', '.magnitude')
+ # ldm_key = ldm_key.replace('.lora_down.weight', '.lora_down')
+ # ldm_key = ldm_key.replace('.lora_up.weight', '.lora_up')
+
+ diffusers_key = diffusers_key.replace('.alpha', '.magnitude')
+ # diffusers_key = diffusers_key.replace('.lora_down.weight', '.lora_down')
+ # diffusers_key = diffusers_key.replace('.lora_up.weight', '.lora_up')
+
+ new_keymap[ldm_key] = diffusers_key
+
+ keymap = new_keymap
+
+ return keymap
+
+ def save_weights(
+ self: Network,
+ file, dtype=torch.float16,
+ metadata=None,
+ extra_state_dict: Optional[OrderedDict] = None
+ ):
+ keymap = self.get_keymap()
+
+ save_keymap = {}
+ if keymap is not None:
+ for ldm_key, diffusers_key in keymap.items():
+ # invert them
+ save_keymap[diffusers_key] = ldm_key
+
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+ save_dict = OrderedDict()
+
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ save_key = save_keymap[key] if key in save_keymap else key
+ save_dict[save_key] = v
+ del state_dict[key]
+
+ if extra_state_dict is not None:
+ # add extra items to state dict
+ for key in list(extra_state_dict.keys()):
+ v = extra_state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ save_dict[key] = v
+
+ if self.peft_format:
+ # lora_down = lora_A
+ # lora_up = lora_B
+ # no alpha
+
+ new_save_dict = {}
+ for key, value in save_dict.items():
+ if key.endswith('.alpha'):
+ continue
+ new_key = key
+ new_key = new_key.replace('lora_down', 'lora_A')
+ new_key = new_key.replace('lora_up', 'lora_B')
+ # replace all $$ with .
+ new_key = new_key.replace('$$', '.')
+ new_save_dict[new_key] = value
+
+ save_dict = new_save_dict
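+ # e.g. a hypothetical key "transformer$$blocks$$0$$attn$$to_q.lora_down.weight" is
+ # saved as "transformer.blocks.0.attn.to_q.lora_A.weight" in PEFT format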
+
+ if metadata is None:
+ metadata = OrderedDict()
+ # hash the tensors that are actually being written (state_dict was drained above)
+ metadata = add_model_hash_to_meta(save_dict, metadata)
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+ save_file(save_dict, file, metadata)
+ else:
+ torch.save(save_dict, file)
+
+ def load_weights(self: Network, file, force_weight_mapping=False):
+ # allows us to save and load to and from ldm weights
+ keymap = self.get_keymap(force_weight_mapping)
+ keymap = {} if keymap is None else keymap
+
+ if isinstance(file, str):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+ else:
+ # probably a state dict
+ weights_sd = file
+
+ load_sd = OrderedDict()
+ for key, value in weights_sd.items():
+ load_key = keymap[key] if key in keymap else key
+ # replace old double __ with single _
+ if self.is_pixart:
+ load_key = load_key.replace('__', '_')
+
+ if self.peft_format:
+ # lora_down = lora_A
+ # lora_up = lora_B
+ # no alpha
+ if load_key.endswith('.alpha'):
+ continue
+ load_key = load_key.replace('lora_A', 'lora_down')
+ load_key = load_key.replace('lora_B', 'lora_up')
+ # replace all . with $$
+ load_key = load_key.replace('.', '$$')
+ load_key = load_key.replace('$$lora_down$$', '.lora_down.')
+ load_key = load_key.replace('$$lora_up$$', '.lora_up.')
+
+ load_sd[load_key] = value
+
+ # extract extra items from state dict
+ current_state_dict = self.state_dict()
+ extra_dict = OrderedDict()
+ to_delete = []
+ for key in list(load_sd.keys()):
+ if key not in current_state_dict:
+ extra_dict[key] = load_sd[key]
+ to_delete.append(key)
+ for key in to_delete:
+ del load_sd[key]
+
+ print(f"Missing keys: {to_delete}")
+ if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
+ len(to_delete) == 1 and 'emb_params' in to_delete):
+ print(" Attempting to load with forced keymap")
+ return self.load_weights(file, force_weight_mapping=True)
+
+ info = self.load_state_dict(load_sd, False)
+ if len(extra_dict.keys()) == 0:
+ extra_dict = None
+ return extra_dict
+
+ @torch.no_grad()
+ def _update_torch_multiplier(self: Network):
+ # builds a tensor for fast usage in the forward pass of the network modules
+ # without having to set it in every single module every time it changes
+ multiplier = self._multiplier
+ # get first module
+ first_module = self.get_all_modules()[0]
+ device = first_module.lora_down.weight.device
+ dtype = first_module.lora_down.weight.dtype
+ with torch.no_grad():
+ tensor_multiplier = None
+ if isinstance(multiplier, int) or isinstance(multiplier, float):
+ tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
+ elif isinstance(multiplier, list):
+ tensor_multiplier = torch.tensor(multiplier).to(device, dtype=dtype)
+ elif isinstance(multiplier, torch.Tensor):
+ tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
+
+ self.torch_multiplier = tensor_multiplier.clone().detach()
+
+ @property
+ def multiplier(self) -> Union[float, List[float], List[List[float]]]:
+ return self._multiplier
+
+ @multiplier.setter
+ def multiplier(self, value: Union[float, List[float], List[List[float]]]):
+ # it takes time to update all the multipliers, so we only do it if the value has changed
+ if self._multiplier == value:
+ return
+ # if we are setting a single value but have a list, keep the list if every item is the same as value
+ self._multiplier = value
+ self._update_torch_multiplier()
+
+ # called when the context manager is entered
+ # ie: with network:
+ def __enter__(self: Network):
+ self.is_active = True
+
+ def __exit__(self: Network, exc_type, exc_value, tb):
+ self.is_active = False
+
+ def force_to(self: Network, device, dtype):
+ self.to(device, dtype)
+ loras = []
+ if hasattr(self, 'unet_loras'):
+ loras += self.unet_loras
+ if hasattr(self, 'text_encoder_loras'):
+ loras += self.text_encoder_loras
+ for lora in loras:
+ lora.to(device, dtype)
+
+ def get_all_modules(self: Network) -> List[Module]:
+ loras = []
+ if hasattr(self, 'unet_loras'):
+ loras += self.unet_loras
+ if hasattr(self, 'text_encoder_loras'):
+ loras += self.text_encoder_loras
+ return loras
+
+ def _update_checkpointing(self: Network):
+ for module in self.get_all_modules():
+ if self.is_checkpointing:
+ module.enable_gradient_checkpointing()
+ else:
+ module.disable_gradient_checkpointing()
+
+ def enable_gradient_checkpointing(self: Network):
+ # not supported
+ self.is_checkpointing = True
+ self._update_checkpointing()
+
+ def disable_gradient_checkpointing(self: Network):
+ # not supported
+ self.is_checkpointing = False
+ self._update_checkpointing()
+
+ def merge_in(self, merge_weight=1.0):
+ if self.network_type.lower() == 'dora':
+ return
+ self.is_merged_in = True
+ for module in self.get_all_modules():
+ module.merge_in(merge_weight)
+
+ def merge_out(self: Network, merge_weight=1.0):
+ if not self.is_merged_in:
+ return
+ self.is_merged_in = False
+ for module in self.get_all_modules():
+ module.merge_out(merge_weight)
+
+ def extract_weight(
+ self: Network,
+ extract_mode: ExtractMode = "existing",
+ extract_mode_param: Union[int, float] = None,
+ ):
+ if extract_mode_param is None:
+ raise ValueError("extract_mode_param must be set")
+ for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
+ module.extract_weight(
+ extract_mode=extract_mode,
+ extract_mode_param=extract_mode_param
+ )
+
+ def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
+ for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
+ module.setup_lorm(state_dict=state_dict)
+
+ def calculate_lorem_parameter_reduction(self):
+ params_reduced = 0
+ for module in self.get_all_modules():
+ num_orig_module_params = count_parameters(module.org_module[0])
+ num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
+ params_reduced += (num_orig_module_params - num_lorem_params)
+
+ return params_reduced
diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1a258ff3a4c37c576d2bf51787cff00af0bafa4
--- /dev/null
+++ b/toolkit/optimizer.py
@@ -0,0 +1,103 @@
+import torch
+
+
+def get_optimizer(
+ params,
+ optimizer_type='adam',
+ learning_rate=1e-6,
+ optimizer_params=None
+):
+ if optimizer_params is None:
+ optimizer_params = {}
+ lower_type = optimizer_type.lower()
+ if lower_type.startswith("dadaptation"):
+ # dadaptation optimizers do not use a standard learning rate; 1.0 is the usual value
+ import dadaptation
+ print("Using DAdaptation optimizer")
+ use_lr = learning_rate
+ if use_lr < 0.1:
+ # dadaptation expects lr values in the 0.1 to 1.0 range; default to 1.0
+ use_lr = 1.0
+ if lower_type.endswith('lion'):
+ optimizer = dadaptation.DAdaptLion(params, eps=1e-6, lr=use_lr, **optimizer_params)
+ elif lower_type.endswith('adam'):
+ optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
+ elif lower_type == 'dadaptation':
+ # backwards compatibility
+ optimizer = dadaptation.DAdaptAdam(params, eps=1e-6, lr=use_lr, **optimizer_params)
+ # warn user that dadaptation is deprecated
+ print("WARNING: Dadaptation optimizer type has been changed to DadaptationAdam. Please update your config.")
+ elif lower_type.startswith("prodigy8bit"):
+ from toolkit.optimizers.prodigy_8bit import Prodigy8bit
+ print("Using Prodigy optimizer")
+ use_lr = learning_rate
+ if use_lr < 0.1:
+ # dadaptation uses different lr that is values of 0.1 to 1.0. default to 1.0
+ use_lr = 1.0
+
+ print(f"Using lr {use_lr}")
+ # let net be the neural network you want to train
+ # you can choose weight decay value based on your problem, 0 by default
+ optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params)
+ elif lower_type.startswith("prodigy"):
+ from prodigyopt import Prodigy
+
+ print("Using Prodigy optimizer")
+ use_lr = learning_rate
+ if use_lr < 0.1:
+ # prodigy expects lr values in the 0.1 to 1.0 range; default to 1.0
+ use_lr = 1.0
+
+ print(f"Using lr {use_lr}")
+ # let net be the neural network you want to train
+ # you can choose weight decay value based on your problem, 0 by default
+ optimizer = Prodigy(params, lr=use_lr, eps=1e-6, **optimizer_params)
+ elif lower_type == "adam8":
+ from toolkit.optimizers.adam8bit import Adam8bit
+
+ optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+ elif lower_type == "adamw8":
+ from toolkit.optimizers.adam8bit import Adam8bit
+
+ optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params)
+ elif lower_type.endswith("8bit"):
+ import bitsandbytes
+
+ if lower_type == "adam8bit":
+ return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+ if lower_type == "ademamix8bit":
+ return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+ elif lower_type == "adamw8bit":
+ return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+ elif lower_type == "lion8bit":
+ return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
+ else:
+ raise ValueError(f'Unknown optimizer type {optimizer_type}')
+ elif lower_type == 'adam':
+ optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+ elif lower_type == 'adamw':
+ optimizer = torch.optim.AdamW(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+ elif lower_type == 'lion':
+ try:
+ from lion_pytorch import Lion
+ return Lion(params, lr=learning_rate, **optimizer_params)
+ except ImportError:
+ raise ImportError("Please install lion_pytorch to use Lion optimizer -> pip install lion-pytorch")
+ elif lower_type == 'adagrad':
+ optimizer = torch.optim.Adagrad(params, lr=float(learning_rate), **optimizer_params)
+ elif lower_type == 'adafactor':
+ from toolkit.optimizers.adafactor import Adafactor
+ if 'relative_step' not in optimizer_params:
+ optimizer_params['relative_step'] = False
+ if 'scale_parameter' not in optimizer_params:
+ optimizer_params['scale_parameter'] = False
+ if 'warmup_init' not in optimizer_params:
+ optimizer_params['warmup_init'] = False
+ optimizer = Adafactor(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
+ elif lower_type == 'automagic':
+ from toolkit.optimizers.automagic import Automagic
+ optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params)
+ else:
+ raise ValueError(f'Unknown optimizer type {optimizer_type}')
+ return optimizer
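+
+
+if __name__ == "__main__":
+ # Minimal usage sketch (illustrative; not part of the toolkit itself): build a throwaway
+ # parameter list and request two of the plain-torch optimizer types handled above.
+ _params = [torch.nn.Parameter(torch.randn(4, 4))]
+ print(get_optimizer(_params, optimizer_type="adamw", learning_rate=1e-4))
+ print(get_optimizer(_params, optimizer_type="adagrad", learning_rate=1e-3))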
diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py
new file mode 100644
index 0000000000000000000000000000000000000000..00cf06ee4ab7afbd201414b2ca2a302ffe539c9f
--- /dev/null
+++ b/toolkit/optimizers/adafactor.py
@@ -0,0 +1,359 @@
+import math
+from typing import List
+import torch
+from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
+from optimum.quanto import QBytesTensor
+import random
+
+
+class Adafactor(torch.optim.Optimizer):
+ """
+ Adafactor implementation with stochastic rounding accumulation and stochastic rounding on apply.
+ Modified from transformers Adafactor implementation to support stochastic rounding accumulation and apply.
+
+ AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
+ https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
+
+ Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
+ this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
+ `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
+ `relative_step=False`.
+
+ Arguments:
+ params (`Iterable[nn.parameter.Parameter]`):
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
+ lr (`float`, *optional*):
+ The external learning rate.
+ eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
+ Regularization constants for square gradient and parameter scale respectively
+ clip_threshold (`float`, *optional*, defaults to 1.0):
+ Threshold of root mean square of final gradient update
+ decay_rate (`float`, *optional*, defaults to -0.8):
+ Coefficient used to compute running averages of square
+ beta1 (`float`, *optional*):
+ Coefficient used for computing running averages of gradient
+ weight_decay (`float`, *optional*, defaults to 0.0):
+ Weight decay (L2 penalty)
+ scale_parameter (`bool`, *optional*, defaults to `True`):
+ If True, learning rate is scaled by root mean square
+ relative_step (`bool`, *optional*, defaults to `True`):
+ If True, time-dependent learning rate is computed instead of external learning rate
+ warmup_init (`bool`, *optional*, defaults to `False`):
+ Time-dependent learning rate computation depends on whether warm-up initialization is being used
+
+ This implementation handles low-precision (FP16, bfloat16) values, but it has not been thoroughly tested.
+
+ Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
+
+ - Training without LR warmup or clip_threshold is not recommended.
+
+ - use scheduled LR warm-up to fixed LR
+ - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
+ - Disable relative updates
+ - Use scale_parameter=False
+ - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
+
+ Example:
+
+ ```python
+ Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
+ ```
+
+ Others reported the following combination to work well:
+
+ ```python
+ Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+ ```
+
+ When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
+ scheduler as following:
+
+ ```python
+ from transformers.optimization import Adafactor, AdafactorSchedule
+
+ optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+ lr_scheduler = AdafactorSchedule(optimizer)
+ trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
+ ```
+
+ Usage:
+
+ ```python
+ # replace AdamW with Adafactor
+ optimizer = Adafactor(
+ model.parameters(),
+ lr=1e-3,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ relative_step=False,
+ scale_parameter=False,
+ warmup_init=False,
+ )
+ ```"""
+
+ def __init__(
+ self,
+ params,
+ lr=None,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ beta1=None,
+ weight_decay=0.0,
+ scale_parameter=True,
+ relative_step=True,
+ warmup_init=False,
+ do_paramiter_swapping=False,
+ paramiter_swapping_factor=0.1,
+ ):
+ if lr is not None and relative_step:
+ raise ValueError(
+ "Cannot combine manual `lr` and `relative_step=True` options")
+ if warmup_init and not relative_step:
+ raise ValueError(
+ "`warmup_init=True` requires `relative_step=True`")
+
+ defaults = {
+ "lr": lr,
+ "eps": eps,
+ "clip_threshold": clip_threshold,
+ "decay_rate": decay_rate,
+ "beta1": beta1,
+ "weight_decay": weight_decay,
+ "scale_parameter": scale_parameter,
+ "relative_step": relative_step,
+ "warmup_init": warmup_init,
+ }
+ super().__init__(params, defaults)
+
+ self.base_lrs: List[float] = [
+ lr for group in self.param_groups
+ ]
+
+ self.is_stochastic_rounding_accumulation = False
+
+ # setup stochastic grad accum hooks
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and param.dtype != torch.float32:
+ self.is_stochastic_rounding_accumulation = True
+ param.register_post_accumulate_grad_hook(
+ stochastic_grad_accummulation
+ )
+
+ self.do_paramiter_swapping = do_paramiter_swapping
+ self.paramiter_swapping_factor = paramiter_swapping_factor
+ self._total_paramiter_size = 0
+ # count total paramiters
+ for group in self.param_groups:
+ for param in group['params']:
+ self._total_paramiter_size += torch.numel(param)
+ # pretty print the total parameter count with comma separation
+ print(f"Total training parameters: {self._total_paramiter_size:,}")
+
+ # swapping can only be enabled after the total parameter count is known
+ if self.do_paramiter_swapping:
+ self.enable_paramiter_swapping(self.paramiter_swapping_factor)
+
+
+ def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
+ self.do_paramiter_swapping = True
+ self.paramiter_swapping_factor = paramiter_swapping_factor
+ # call it an initial time
+ self.swap_paramiters()
+
+ def swap_paramiters(self):
+ all_params = []
+ # deactivate all paramiters
+ for group in self.param_groups:
+ for param in group['params']:
+ param.requires_grad_(False)
+ # remove any grad
+ param.grad = None
+ all_params.append(param)
+ # shuffle all paramiters
+ random.shuffle(all_params)
+
+ # keep activating parameters until we would exceed the target parameter budget
+ target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor)
+ total_paramiters = 0
+ for param in all_params:
+ total_paramiters += torch.numel(param)
+ if total_paramiters >= target_paramiters:
+ break
+ else:
+ param.requires_grad_(True)
+
+ @staticmethod
+ def _get_lr(param_group, param_state):
+ rel_step_sz = param_group["lr"]
+ if param_group["relative_step"]:
+ min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+ rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
+ param_scale = 1.0
+ if param_group["scale_parameter"]:
+ param_scale = max(param_group["eps"][1], param_state["RMS"])
+ return param_scale * rel_step_sz
+
+ @staticmethod
+ def _get_options(param_group, param_shape):
+ factored = len(param_shape) >= 2
+ use_first_moment = param_group["beta1"] is not None
+ return factored, use_first_moment
+
+ @staticmethod
+ def _rms(tensor):
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+ @staticmethod
+ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+ # copy from fairseq's adafactor implementation:
+ # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+ return torch.mul(r_factor, c_factor)
+
+ def step_hook(self):
+ if not self.is_stochastic_rounding_accumulation:
+ return
+ # copy over stochastically rounded grads
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and hasattr(param, "_accum_grad"):
+ param.grad = param._accum_grad
+ del param._accum_grad
+
+ # adafactor manages its own lr
+ def get_learning_rates(self):
+ lrs = [
+ self._get_lr(group, self.state[group["params"][0]])
+ for group in self.param_groups
+ if group["params"][0].grad is not None
+ ]
+ if len(lrs) == 0:
+ lrs = self.base_lrs # if called before stepping
+ return lrs
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """
+ Performs a single optimization step
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ self.step_hook()
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None or not p.requires_grad:
+ continue
+
+ grad = p.grad
+ if grad.dtype != torch.float32:
+ grad = grad.to(torch.float32)
+ if grad.is_sparse:
+ raise RuntimeError(
+ "Adafactor does not support sparse gradients.")
+
+ # if p has atts _scale then it is quantized. We need to divide the grad by the scale
+ # if hasattr(p, "_scale"):
+ # grad = grad / p._scale
+
+ state = self.state[p]
+ grad_shape = grad.shape
+
+ factored, use_first_moment = self._get_options(
+ group, grad_shape)
+ # State Initialization
+ if len(state) == 0:
+ state["step"] = 0
+
+ if use_first_moment:
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(grad)
+ if factored:
+ state["exp_avg_sq_row"] = torch.zeros(
+ grad_shape[:-1]).to(grad)
+ state["exp_avg_sq_col"] = torch.zeros(
+ grad_shape[:-2] + grad_shape[-1:]).to(grad)
+ else:
+ state["exp_avg_sq"] = torch.zeros_like(grad)
+
+ state["RMS"] = 0
+ else:
+ if use_first_moment:
+ state["exp_avg"] = state["exp_avg"].to(grad)
+ if factored:
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
+ grad)
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
+ grad)
+ else:
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+ p_data_fp32 = p
+
+ if isinstance(p_data_fp32, QBytesTensor):
+ p_data_fp32 = p_data_fp32.dequantize()
+ if p.dtype != torch.float32:
+ p_data_fp32 = p_data_fp32.clone().float()
+
+ state["step"] += 1
+ state["RMS"] = self._rms(p_data_fp32)
+ lr = self._get_lr(group, state)
+
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+ eps = group["eps"]
+ if isinstance(eps, tuple) or isinstance(eps, list):
+ eps = eps[0]
+ update = (grad**2) + eps
+ if factored:
+ exp_avg_sq_row = state["exp_avg_sq_row"]
+ exp_avg_sq_col = state["exp_avg_sq_col"]
+
+ exp_avg_sq_row.mul_(beta2t).add_(
+ update.mean(dim=-1), alpha=(1.0 - beta2t))
+ exp_avg_sq_col.mul_(beta2t).add_(
+ update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+ # Approximation of exponential moving average of square of gradient
+ update = self._approx_sq_grad(
+ exp_avg_sq_row, exp_avg_sq_col)
+ update.mul_(grad)
+ else:
+ exp_avg_sq = state["exp_avg_sq"]
+
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+ update = exp_avg_sq.rsqrt().mul_(grad)
+
+ update.div_(
+ (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+ update.mul_(lr)
+
+ if use_first_moment:
+ exp_avg = state["exp_avg"]
+ exp_avg.mul_(group["beta1"]).add_(
+ update, alpha=(1 - group["beta1"]))
+ update = exp_avg
+
+ if group["weight_decay"] != 0:
+ p_data_fp32.add_(
+ p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+ p_data_fp32.add_(-update)
+
+ if p.dtype != torch.float32:
+ # apply stochastic rounding
+ copy_stochastic(p, p_data_fp32)
+
+ return loss
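+
+
+if __name__ == "__main__":
+ # Small sketch (illustrative; not part of the toolkit): with bfloat16 parameters the optimizer
+ # registers post-accumulate-grad hooks so that gradient accumulation and the final parameter
+ # update both go through stochastic rounding.
+ net = torch.nn.Linear(8, 8).to(torch.bfloat16)
+ opt = Adafactor(net.parameters(), lr=1e-3, relative_step=False, scale_parameter=False, warmup_init=False)
+ loss = net(torch.randn(2, 8, dtype=torch.bfloat16)).float().pow(2).mean()
+ loss.backward()
+ opt.step()
+ opt.zero_grad()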
diff --git a/toolkit/optimizers/adam8bit.py b/toolkit/optimizers/adam8bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5fc976bf456c6f93e26776a614e150450e1875e
--- /dev/null
+++ b/toolkit/optimizers/adam8bit.py
@@ -0,0 +1,162 @@
+import math
+import torch
+from torch.optim import Optimizer
+from toolkit.optimizers.optimizer_utils import copy_stochastic, Auto8bitTensor, stochastic_grad_accummulation
+
+class Adam8bit(Optimizer):
+ """
+ Implements Adam optimizer with 8-bit state storage and stochastic rounding.
+
+ Arguments:
+ params (iterable): Iterable of parameters to optimize or dicts defining parameter groups
+ lr (float): Learning rate (default: 1e-3)
+ betas (tuple): Coefficients for computing running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float): Term added to denominator to improve numerical stability (default: 1e-8)
+ weight_decay (float): Weight decay coefficient (default: 0)
+ decouple (bool): Use AdamW style decoupled weight decay (default: True)
+ """
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, decouple=True):
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= eps:
+ raise ValueError(f"Invalid epsilon value: {eps}")
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+ decouple=decouple)
+ super(Adam8bit, self).__init__(params, defaults)
+
+ self.is_stochastic_rounding_accumulation = False
+
+ # Setup stochastic grad accumulation hooks
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and param.dtype != torch.float32:
+ self.is_stochastic_rounding_accumulation = True
+ param.register_post_accumulate_grad_hook(
+ stochastic_grad_accummulation
+ )
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return False
+
+ @property
+ def supports_flat_params(self):
+ return True
+
+ def step_hook(self):
+ if not self.is_stochastic_rounding_accumulation:
+ return
+ # Copy over stochastically rounded grads
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and hasattr(param, "_accum_grad"):
+ param.grad = param._accum_grad
+ del param._accum_grad
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
+ """
+ # Call pre step
+ self.step_hook()
+
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ beta1, beta2 = group['betas']
+ eps = group['eps']
+ lr = group['lr']
+ decay = group['weight_decay']
+ decouple = group['decouple']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ grad = p.grad.data.to(torch.float32)
+ p_fp32 = p.clone().to(torch.float32)
+
+ # Apply weight decay (coupled variant)
+ if decay != 0 and not decouple:
+ grad.add_(p_fp32.data, alpha=decay)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = Auto8bitTensor(
+ torch.zeros_like(p_fp32.data).detach())
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = Auto8bitTensor(
+ torch.zeros_like(p_fp32.data).detach())
+
+ exp_avg = state['exp_avg'].to(torch.float32)
+ exp_avg_sq = state['exp_avg_sq'].to(torch.float32)
+
+ state['step'] += 1
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+
+ # Adam EMA updates
+ exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)
+
+ # Apply weight decay (decoupled variant)
+ if decay != 0 and decouple:
+ p_fp32.data.mul_(1 - lr * decay)
+
+ # Bias correction
+ step_size = lr / bias_correction1
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+ # Take step
+ p_fp32.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+ # Update state with stochastic rounding
+ state['exp_avg'] = Auto8bitTensor(exp_avg)
+ state['exp_avg_sq'] = Auto8bitTensor(exp_avg_sq)
+
+ # Apply stochastic rounding to parameters
+ copy_stochastic(p.data, p_fp32.data)
+
+ return loss
+
+ def state_dict(self):
+ """Returns the state of the optimizer as a dict."""
+ state_dict = super().state_dict()
+
+ # Convert Auto8bitTensor objects to regular state dicts
+ for param_id, param_state in state_dict['state'].items():
+ for key, value in param_state.items():
+ if isinstance(value, Auto8bitTensor):
+ param_state[key] = {
+ '_type': 'Auto8bitTensor',
+ 'state': value.state_dict()
+ }
+
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ """Loads the optimizer state."""
+ # First, load the basic state
+ super().load_state_dict(state_dict)
+
+ # Then convert any Auto8bitTensor states back to objects
+ for param_id, param_state in self.state.items():
+ for key, value in param_state.items():
+ if isinstance(value, dict) and value.get('_type') == 'Auto8bitTensor':
+ param_state[key] = Auto8bitTensor(value['state'])
+
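+
+if __name__ == "__main__":
+ # Small sketch (illustrative; not part of the toolkit): both Adam moments are held as
+ # Auto8bitTensor objects, so optimizer state costs roughly 2 bytes per parameter instead of 8.
+ net = torch.nn.Linear(8, 8)
+ opt = Adam8bit(net.parameters(), lr=1e-3)
+ loss = net(torch.randn(2, 8)).pow(2).mean()
+ loss.backward()
+ opt.step()
+ opt.zero_grad()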
diff --git a/toolkit/optimizers/automagic.py b/toolkit/optimizers/automagic.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7355f168aa03d77a01124c5cb42f1ea625a61a
--- /dev/null
+++ b/toolkit/optimizers/automagic.py
@@ -0,0 +1,335 @@
+from collections import OrderedDict
+import math
+from typing import List
+import torch
+from toolkit.optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation
+from optimum.quanto import QBytesTensor
+import random
+
+
+class Automagic(torch.optim.Optimizer):
+ def __init__(
+ self,
+ params,
+ lr=None,
+ min_lr=1e-7,
+ max_lr=1e-3,
+ lr_pump_scale=1.1,
+ lr_dump_scale=0.85,
+ eps=(1e-30, 1e-3),
+ clip_threshold=1.0,
+ decay_rate=-0.8,
+ weight_decay=0.0,
+ do_paramiter_swapping=False,
+ paramiter_swapping_factor=0.1,
+ ):
+ self.lr = lr
+ self.min_lr = min_lr
+ self.max_lr = max_lr
+ self.lr_pump_scale = lr_pump_scale
+ self.lr_dump_scale = lr_dump_scale
+
+ defaults = {
+ "lr": lr,
+ "eps": eps,
+ "clip_threshold": clip_threshold,
+ "decay_rate": decay_rate,
+ "weight_decay": weight_decay,
+ }
+ super().__init__(params, defaults)
+
+ self.base_lrs: List[float] = [
+ lr for group in self.param_groups
+ ]
+
+ self.is_stochastic_rounding_accumulation = False
+
+ # setup stochastic grad accum hooks
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and param.dtype != torch.float32:
+ self.is_stochastic_rounding_accumulation = True
+ param.register_post_accumulate_grad_hook(
+ stochastic_grad_accummulation
+ )
+
+ self.do_paramiter_swapping = do_paramiter_swapping
+ self.paramiter_swapping_factor = paramiter_swapping_factor
+ self._total_paramiter_size = 0
+ # count total paramiters
+ for group in self.param_groups:
+ for param in group['params']:
+ self._total_paramiter_size += torch.numel(param)
+ # pretty print the total parameter count with comma separation
+ print(f"Total training parameters: {self._total_paramiter_size:,}")
+
+ # swapping can only be enabled after the total parameter count is known
+ if self.do_paramiter_swapping:
+ self.enable_paramiter_swapping(self.paramiter_swapping_factor)
+
+ def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
+ self.do_paramiter_swapping = True
+ self.paramiter_swapping_factor = paramiter_swapping_factor
+ # call it an initial time
+ self.swap_paramiters()
+
+ def swap_paramiters(self):
+ all_params = []
+ # deactivate all paramiters
+ for group in self.param_groups:
+ for param in group['params']:
+ param.requires_grad_(False)
+ # remove any grad
+ param.grad = None
+ all_params.append(param)
+ # shuffle all paramiters
+ random.shuffle(all_params)
+
+ # keep activating parameters until we would exceed the target parameter budget
+ target_paramiters = int(
+ self._total_paramiter_size * self.paramiter_swapping_factor)
+ total_paramiters = 0
+ for param in all_params:
+ total_paramiters += torch.numel(param)
+ if total_paramiters >= target_paramiters:
+ break
+ else:
+ param.requires_grad_(True)
+
+ @staticmethod
+ def _get_lr(param_group, param_state):
+ if 'avg_lr' in param_state:
+ lr = param_state["avg_lr"]
+ else:
+ lr = 0.0
+ return lr
+
+ def _get_group_lr(self, group):
+ group_lrs = []
+ for p in group["params"]:
+ group_lrs.append(self._get_lr(group, self.state[p]))
+ # return avg
+ if len(group_lrs) == 0:
+ return self.lr
+ return sum(group_lrs) / len(group_lrs)
+
+ @staticmethod
+ def _rms(tensor):
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+ @staticmethod
+ def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+ # copy from fairseq's adafactor implementation:
+ # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+ return torch.mul(r_factor, c_factor)
+
+ def step_hook(self):
+ if not self.is_stochastic_rounding_accumulation:
+ return
+ # copy over stochastically rounded grads
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and hasattr(param, "_accum_grad"):
+ param.grad = param._accum_grad
+ del param._accum_grad
+
+ # automagic manages its own per-parameter lr
+ def get_learning_rates(self):
+
+ lrs = [
+ self._get_group_lr(group)
+ for group in self.param_groups
+ ]
+ if len(lrs) == 0:
+ lrs = self.base_lrs # if called before stepping
+ return lrs
+
+ def get_avg_learning_rate(self):
+ lrs = self.get_learning_rates()
+ return sum(lrs) / len(lrs)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """
+ Performs a single optimization step
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ self.step_hook()
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None or not p.requires_grad:
+ continue
+
+ grad = p.grad
+ if grad.dtype != torch.float32:
+ grad = grad.to(torch.float32)
+ if grad.is_sparse:
+ raise RuntimeError(
+ "Automagic does not support sparse gradients.")
+
+ state = self.state[p]
+ grad_shape = grad.shape
+
+ factored = len(grad_shape) >= 2
+ # State Initialization
+ if len(state) == 0:
+ self.initialize_state(p)
+ else:
+ if factored:
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(
+ grad)
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(
+ grad)
+ else:
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+ p_data_fp32 = p
+
+ if isinstance(p_data_fp32, QBytesTensor):
+ p_data_fp32 = p_data_fp32.dequantize()
+ if p.dtype != torch.float32:
+ p_data_fp32 = p_data_fp32.clone().float()
+
+ state["step"] += 1
+ state["RMS"] = self._rms(p_data_fp32)
+ # lr = self._get_lr(group, state)
+
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+ eps = group["eps"]
+ if isinstance(eps, tuple) or isinstance(eps, list):
+ eps = eps[0]
+ update = (grad**2) + eps
+ if factored:
+ exp_avg_sq_row = state["exp_avg_sq_row"]
+ exp_avg_sq_col = state["exp_avg_sq_col"]
+
+ exp_avg_sq_row.mul_(beta2t).add_(
+ update.mean(dim=-1), alpha=(1.0 - beta2t))
+ exp_avg_sq_col.mul_(beta2t).add_(
+ update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+ # Approximation of exponential moving average of square of gradient
+ update = self._approx_sq_grad(
+ exp_avg_sq_row, exp_avg_sq_col)
+ update.mul_(grad)
+ else:
+ exp_avg_sq = state["exp_avg_sq"]
+
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+ update = exp_avg_sq.rsqrt().mul_(grad)
+
+ update.div_(
+ (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+
+ # calculate the new per-parameter lr mask: if the current update pushes a parameter in the same
+ # direction as the previous update, raise that parameter's lr (multiply by lr_pump_scale); if the
+ # direction flips, lower it (multiply by lr_dump_scale). The mask is clamped to [min_lr, max_lr]
+ # before being applied to the update below.
+
+ # the global lr is no longer applied directly here
+ # update.mul_(lr)
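+ # Worked example using the defaults above (lr_pump_scale=1.1, lr_dump_scale=0.85) and a
+ # hypothetical starting lr of 1e-4: three consecutive sign agreements give 1e-4 * 1.1**3 ≈ 1.33e-4,
+ # while three consecutive flips give 1e-4 * 0.85**3 ≈ 6.1e-5.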
+
+ # Compare the sign of the current update with the sign of the previous update
+ last_polarity = state['last_polarity']
+ current_polarity = (update > 0).to(torch.bool)
+ sign_agreement = torch.where(
+ last_polarity == current_polarity, 1, -1)
+ state['last_polarity'] = current_polarity
+
+ lr_mask = state['lr_mask'].to(torch.float32)
+
+ # Update learning rate mask based on sign agreement
+ new_lr = torch.where(
+ sign_agreement > 0,
+ lr_mask * self.lr_pump_scale, # Increase lr
+ lr_mask * self.lr_dump_scale # Decrease lr
+ )
+
+ # Clip learning rates to bounds
+ new_lr = torch.clamp(
+ new_lr,
+ min=self.min_lr,
+ max=self.max_lr
+ )
+
+ # Apply the learning rate mask to the update
+ update.mul_(new_lr)
+
+ state['lr_mask'] = Auto8bitTensor(new_lr)
+ state['avg_lr'] = torch.mean(new_lr)
+
+ if group["weight_decay"] != 0:
+ p_data_fp32.add_(
+ p_data_fp32, alpha=(-group["weight_decay"] * new_lr))
+
+ p_data_fp32.add_(-update)
+
+ if p.dtype != torch.float32:
+ # apply stochastic rounding
+ copy_stochastic(p, p_data_fp32)
+
+ return loss
+
+ def initialize_state(self, p):
+ state = self.state[p]
+ state["step"] = 0
+
+ # store the lr mask
+ if 'lr_mask' not in state:
+ state['lr_mask'] = Auto8bitTensor(torch.ones(
+ p.shape).to(p.device, dtype=torch.float32) * self.lr
+ )
+ state['avg_lr'] = torch.mean(
+ state['lr_mask'].to(torch.float32))
+ if 'last_polarity' not in state:
+ state['last_polarity'] = torch.zeros(
+ p.shape, dtype=torch.bool, device=p.device)
+
+ factored = len(p.shape) >= 2
+ if factored:
+ state["exp_avg_sq_row"] = torch.zeros(
+ p.shape[:-1]).to(p)
+ state["exp_avg_sq_col"] = torch.zeros(
+ p.shape[:-2] + p.shape[-1:]).to(p)
+ else:
+ state["exp_avg_sq"] = torch.zeros_like(p)
+
+ state["RMS"] = 0
+
+ # override the state_dict to save the lr_mask
+ def state_dict(self, *args, **kwargs):
+ orig_state_dict = super().state_dict(*args, **kwargs)
+ # serialize each lr_mask (an Auto8bitTensor) as its quantized state dict
+ new_save_state = {}
+ for p, state in orig_state_dict['state'].items():
+ save_state = {k: v for k, v in state.items() if k != 'lr_mask'}
+ save_state['lr_mask'] = state['lr_mask'].state_dict()
+ new_save_state[p] = save_state
+
+ orig_state_dict['state'] = new_save_state
+
+ return orig_state_dict
+
+ def load_state_dict(self, state_dict, strict=True):
+ # load the lr_mask from the state_dict
+ idx = 0
+ for group in self.param_groups:
+ for p in group['params']:
+ self.initialize_state(p)
+ state = self.state[p]
+ m = state_dict['state'][idx]['lr_mask']
+ sd_mask = m['quantized'].to(m['orig_dtype']) * m['scale']
+ state['lr_mask'] = Auto8bitTensor(sd_mask)
+ del state_dict['state'][idx]['lr_mask']
+ idx += 1
+ super().load_state_dict(state_dict)
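+
+
+if __name__ == "__main__":
+ # Small sketch (illustrative; not part of the toolkit): the configured lr only seeds the
+ # per-parameter lr mask, which then adapts between min_lr and max_lr on its own.
+ net = torch.nn.Linear(8, 8)
+ opt = Automagic(net.parameters(), lr=1e-5)
+ loss = net(torch.randn(2, 8)).pow(2).mean()
+ loss.backward()
+ opt.step()
+ print(f"average learned lr: {float(opt.get_avg_learning_rate()):.2e}")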
diff --git a/toolkit/optimizers/optimizer_utils.py b/toolkit/optimizers/optimizer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..67991f21a5da53f22425dd8519eda3fc159bb93c
--- /dev/null
+++ b/toolkit/optimizers/optimizer_utils.py
@@ -0,0 +1,256 @@
+import torch
+from torch import Tensor
+from typing import Optional
+from optimum.quanto import QBytesTensor
+
+
+def compute_scale_for_dtype(tensor, dtype):
+ """
+ Compute appropriate scale for the given tensor and target dtype.
+
+ Args:
+ tensor: Input tensor to be quantized
+ dtype: Target dtype for quantization
+ Returns:
+ Appropriate scale factor for the quantization
+ """
+ if dtype == torch.int8:
+ abs_max = torch.max(torch.abs(tensor))
+ return abs_max / 127.0 if abs_max > 0 else 1.0
+ elif dtype == torch.uint8:
+ max_val = torch.max(tensor)
+ min_val = torch.min(tensor)
+ range_val = max_val - min_val
+ return range_val / 255.0 if range_val > 0 else 1.0
+ elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+ # For float8, we typically want to preserve the magnitude of the values
+ # while fitting within the representable range of the format
+ abs_max = torch.max(torch.abs(tensor))
+ if dtype == torch.float8_e4m3fn:
+ # e4m3fn has range [-448, 448] with no infinities
+ max_representable = 448.0
+ else: # torch.float8_e5m2
+ # e5m2 has range [-57344, 57344] with infinities
+ max_representable = 57344.0
+
+ return abs_max / max_representable if abs_max > 0 else 1.0
+ else:
+ raise ValueError(f"Unsupported dtype for quantization: {dtype}")
+
+def quantize_tensor(tensor, dtype):
+ """
+ Quantize a floating-point tensor to the target dtype with appropriate scaling.
+
+ Args:
+ tensor: Input tensor (float)
+ dtype: Target dtype for quantization
+ Returns:
+ quantized_data: Quantized tensor
+ scale: Scale factor used
+ """
+ scale = compute_scale_for_dtype(tensor, dtype)
+
+ if dtype == torch.int8:
+ quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
+ elif dtype == torch.uint8:
+ quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
+ elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+ # For float8, we scale and then cast directly to the target type
+ # The casting operation will handle the appropriate rounding
+ scaled_tensor = tensor / scale
+ quantized_data = scaled_tensor.to(dtype)
+ else:
+ raise ValueError(f"Unsupported dtype for quantization: {dtype}")
+
+ return quantized_data, scale
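+
+# Worked example (illustrative): for a tensor whose largest magnitude is 2.54, compute_scale_for_dtype
+# returns 2.54 / 127 = 0.02 for int8, and quantize_tensor then maps 2.54 -> 127, 1.0 -> 50, 0.5 -> 25.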
+
+
+def update_parameter(target, result_float):
+ """
+ Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases
+ with proper rescaling for quantized tensors.
+
+ Args:
+ target: The parameter to update (either torch.Tensor or QBytesTensor)
+ result_float: The new values to assign (torch.Tensor)
+ """
+ if isinstance(target, QBytesTensor):
+ # Get the target dtype from the existing quantized tensor
+ target_dtype = target._data.dtype
+
+ # Handle device placement
+ device = target._data.device
+ result_float = result_float.to(device)
+
+ # Compute new quantized values and scale
+ quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
+
+ # Update the internal tensors with newly computed values
+ target._data.copy_(quantized_data)
+ target._scale.copy_(new_scale)
+ else:
+ # Regular tensor update
+ target.copy_(result_float)
+
+
+def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
+ """
+ Returns (mantissa_bits, total_bits) for each format.
+ mantissa_bits excludes the implicit leading 1.
+ """
+ if dtype == torch.float32:
+ return 23, 32
+ elif dtype == torch.bfloat16:
+ return 7, 16
+ elif dtype == torch.float16:
+ return 10, 16
+ elif dtype == torch.float8_e4m3fn:
+ return 3, 8
+ elif dtype == torch.float8_e5m2:
+ return 2, 8
+ elif dtype == torch.int8:
+ return 0, 8 # Int8 doesn't have mantissa bits
+ else:
+ raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def copy_stochastic(
+ target: torch.Tensor,
+ source: torch.Tensor,
+ eps: Optional[float] = None
+) -> None:
+ """
+ Performs stochastic rounding from source tensor to target tensor.
+
+ Args:
+ target: Destination tensor (determines the target format)
+ source: Source tensor (typically float32)
+ eps: Optional minimum value for stochastic rounding (for numerical stability)
+ """
+ with torch.no_grad():
+ # If target is float32, just copy directly
+ if target.dtype == torch.float32:
+ target.copy_(source)
+ return
+
+ # Special handling for int8
+ if target.dtype == torch.int8:
+ # Scale the source values to utilize the full int8 range
+ scaled = source * 127.0 # Scale to [-127, 127]
+
+ # Add random noise for stochastic rounding
+ noise = torch.rand_like(scaled) - 0.5
+ rounded = torch.round(scaled + noise)
+
+ # Clamp to int8 range
+ clamped = torch.clamp(rounded, -127, 127)
+ target.copy_(clamped.to(torch.int8))
+ return
+
+ mantissa_bits, _ = get_format_params(target.dtype)
+
+ # Convert source to int32 view
+ source_int = source.view(dtype=torch.int32)
+
+ # Calculate number of bits to round
+ bits_to_round = 23 - mantissa_bits # 23 is float32 mantissa bits
+
+ # Create random integers for stochastic rounding
+ rand = torch.randint_like(
+ source,
+ dtype=torch.int32,
+ low=0,
+ high=(1 << bits_to_round),
+ )
+
+ # Add random values to the bits that will be rounded off
+ result = source_int.clone()
+ result.add_(rand)
+
+ # Mask to keep only the bits we want
+ # Create mask with 1s in positions we want to keep
+ mask = (-1) << bits_to_round
+ result.bitwise_and_(mask)
+
+ # Handle minimum value threshold if specified
+ if eps is not None:
+ eps_int = torch.tensor(
+ eps, dtype=torch.float32).view(dtype=torch.int32)
+ zero_mask = (result.abs() < eps_int)
+ result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int
+
+ # Convert back to float32 view
+ result_float = result.view(dtype=torch.float32)
+
+ # Special handling for float8 formats
+ if target.dtype == torch.float8_e4m3fn:
+ result_float.clamp_(-448.0, 448.0)
+ elif target.dtype == torch.float8_e5m2:
+ result_float.clamp_(-57344.0, 57344.0)
+
+ # Copy the result to the target tensor
+ update_parameter(target, result_float)
+ # target.copy_(result_float)
+ del result, rand, source_int
+
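+# Illustrative example: with a bfloat16 target (7 mantissa bits) the representable neighbours of
+# 1.0009 are 1.0 and 1.0078125; copy_stochastic picks one at random, weighting the nearer value more
+# heavily, so the expected result equals the float32 source value.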
+
+class Auto8bitTensor:
+ def __init__(self, data: Tensor, *args, **kwargs):
+ if isinstance(data, dict): # Add constructor from state dict
+ self._load_from_state_dict(data)
+ else:
+ abs_max = data.abs().max().item()
+ scale = abs_max / 127.0 if abs_max > 0 else 1.0
+
+ self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
+ self.scale = scale
+ self.orig_dtype = data.dtype
+
+ def dequantize(self) -> Tensor:
+ return self.quantized.to(dtype=torch.float32) * self.scale
+
+ def to(self, *args, **kwargs):
+ # Handle the dtype argument whether it's positional or keyword
+ dtype = None
+ if args and isinstance(args[0], torch.dtype):
+ dtype = args[0]
+ args = args[1:]
+ elif 'dtype' in kwargs:
+ dtype = kwargs['dtype']
+ del kwargs['dtype']
+
+ if dtype is not None:
+ # First dequantize then convert to requested dtype
+ return self.dequantize().to(dtype=dtype, *args, **kwargs)
+
+ # If no dtype specified, just pass through to parent
+ return self.dequantize().to(*args, **kwargs)
+
+ def state_dict(self):
+ """Returns a dictionary containing the current state of the tensor."""
+ return {
+ 'quantized': self.quantized,
+ 'scale': self.scale,
+ 'orig_dtype': self.orig_dtype
+ }
+
+ def _load_from_state_dict(self, state_dict):
+ """Loads the tensor state from a state dictionary."""
+ self.quantized = state_dict['quantized']
+ self.scale = state_dict['scale']
+ self.orig_dtype = state_dict['orig_dtype']
+
+ def __str__(self):
+ return f"Auto8bitTensor({self.dequantize()})"
+
+
+def stochastic_grad_accummulation(param):
+ if hasattr(param, "_accum_grad"):
+ grad_fp32 = param._accum_grad.clone().to(torch.float32)
+ grad_fp32.add_(param.grad.to(torch.float32))
+ copy_stochastic(param._accum_grad, grad_fp32)
+ del grad_fp32
+ del param.grad
+ else:
+ param._accum_grad = param.grad.clone()
+ del param.grad
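+
+
+if __name__ == "__main__":
+ # Quick illustrative check (not part of the toolkit API): quantize a random tensor to 8 bit with
+ # Auto8bitTensor and confirm the dequantized copy stays close to the original.
+ t = torch.randn(16, 16)
+ q = Auto8bitTensor(t)
+ print(f"max abs quantization error: {(q.dequantize() - t).abs().max().item():.5f}")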
diff --git a/toolkit/optimizers/prodigy_8bit.py b/toolkit/optimizers/prodigy_8bit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7f09149583da67d8f4fbaea6051b0b6694e467
--- /dev/null
+++ b/toolkit/optimizers/prodigy_8bit.py
@@ -0,0 +1,286 @@
+import math
+import torch
+import torch.distributed as dist
+from torch.optim import Optimizer
+from toolkit.optimizers.optimizer_utils import copy_stochastic, Auto8bitTensor, stochastic_grad_accummulation
+
+
+class Prodigy8bit(Optimizer):
+ r"""
+ Implements Adam with Prodigy step-sizes.
+ Handles stochastic rounding for various precisions as well as stochastic gradient accumulation.
+ Stores state in 8bit for memory savings.
+ Leave LR set to 1 unless you encounter instability.
+
+ Arguments:
+ params (iterable):
+ Iterable of parameters to optimize or dicts defining parameter groups.
+ lr (float):
+ Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate.
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ beta3 (float):
+ coefficient for computing the Prodigy stepsize using running averages.
+ If set to None, uses the square root of beta2 (default: None).
+ eps (float):
+ Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
+ weight_decay (float):
+ Weight decay, i.e. a L2 penalty (default: 0).
+ decouple (boolean):
+ Use AdamW style decoupled weight decay
+ use_bias_correction (boolean):
+ Turn on Adam's bias correction. Off by default.
+ safeguard_warmup (boolean):
+ Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default.
+ d0 (float):
+ Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
+ d_coef (float):
+ Coefficient in the expression for the estimate of d (default 1.0).
+ Values such as 0.5 and 2.0 typically work as well.
+ Changing this parameter is the preferred way to tune the method.
+ growth_rate (float):
+ prevent the D estimate from growing faster than this multiplicative rate.
+ Default is inf, for unrestricted. Values like 1.02 give a kind of learning
+ rate warmup effect.
+ fsdp_in_use (bool):
+ If you're using sharded parameters, this should be set to True. The optimizer
+ will attempt to auto-detect this, but if you're using an implementation other
+ than PyTorch's builtin version, the auto-detection won't work.
+ """
+
+ def __init__(self, params, lr=1.0,
+ betas=(0.9, 0.999), beta3=None,
+ eps=1e-8, weight_decay=0, decouple=True,
+ use_bias_correction=False, safeguard_warmup=False,
+ d0=1e-6, d_coef=1.0, growth_rate=float('inf'),
+ fsdp_in_use=False):
+ if not 0.0 < d0:
+ raise ValueError("Invalid d0 value: {}".format(d0))
+ if not 0.0 < lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 < eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError(
+ "Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError(
+ "Invalid beta parameter at index 1: {}".format(betas[1]))
+
+ if decouple and weight_decay > 0:
+ print(f"Using decoupled weight decay")
+
+ defaults = dict(lr=lr, betas=betas, beta3=beta3,
+ eps=eps, weight_decay=weight_decay,
+ d=d0, d0=d0, d_max=d0,
+ d_numerator=0.0, d_coef=d_coef,
+ k=0, growth_rate=growth_rate,
+ use_bias_correction=use_bias_correction,
+ decouple=decouple, safeguard_warmup=safeguard_warmup,
+ fsdp_in_use=fsdp_in_use)
+ self.d0 = d0
+ super(Prodigy8bit, self).__init__(params, defaults)
+
+ self.is_stochastic_rounding_accumulation = False
+
+ # setup stochastic grad accum hooks
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and param.dtype != torch.float32:
+ self.is_stochastic_rounding_accumulation = True
+ param.register_post_accumulate_grad_hook(
+ stochastic_grad_accummulation
+ )
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return False
+
+ @property
+ def supports_flat_params(self):
+ return True
+
+ def step_hook(self):
+ if not self.is_stochastic_rounding_accumulation:
+ return
+ # copy over stochastically rounded grads
+ for group in self.param_groups:
+ for param in group['params']:
+ if param.requires_grad and hasattr(param, "_accum_grad"):
+ param.grad = param._accum_grad
+ del param._accum_grad
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ # call pre step
+ self.step_hook()
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ d_denom = 0.0
+
+ group = self.param_groups[0]
+ use_bias_correction = group['use_bias_correction']
+ beta1, beta2 = group['betas']
+ beta3 = group['beta3']
+ if beta3 is None:
+ beta3 = math.sqrt(beta2)
+ k = group['k']
+
+ d = group['d']
+ d_max = group['d_max']
+ d_coef = group['d_coef']
+ lr = max(group['lr'] for group in self.param_groups)
+
+ if use_bias_correction:
+ bias_correction = ((1 - beta2**(k+1))**0.5) / (1 - beta1**(k+1))
+ else:
+ bias_correction = 1
+
+ dlr = d*lr*bias_correction
+
+ growth_rate = group['growth_rate']
+ decouple = group['decouple']
+ fsdp_in_use = group['fsdp_in_use']
+
+ d_numerator = group['d_numerator']
+ d_numerator *= beta3
+
+ for group in self.param_groups:
+ decay = group['weight_decay']
+ k = group['k']
+ eps = group['eps']
+ group_lr = group['lr']
+ d0 = group['d0']
+ safeguard_warmup = group['safeguard_warmup']
+
+ if group_lr not in [lr, 0.0]:
+ raise RuntimeError(
+ f"Setting different lr values in different parameter groups is only supported for values of 0")
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ if hasattr(p, "_fsdp_flattened"):
+ fsdp_in_use = True
+
+ grad = p.grad.data.to(torch.float32)
+ p_fp32 = p.clone().to(torch.float32)
+
+ # Apply weight decay (coupled variant)
+ if decay != 0 and not decouple:
+ grad.add_(p_fp32.data, alpha=decay)
+
+ state = self.state[p]
+
+ # State initialization
+ if 'step' not in state:
+ state['step'] = 0
+ state['s'] = Auto8bitTensor(
+ torch.zeros_like(p_fp32.data).detach())
+ state['p0'] = Auto8bitTensor(p_fp32.detach().clone())
+ # Exponential moving average of gradient values
+ state['exp_avg'] = Auto8bitTensor(
+ torch.zeros_like(p_fp32.data).detach())
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = Auto8bitTensor(
+ torch.zeros_like(p_fp32.data).detach())
+
+ exp_avg = state['exp_avg'].to(torch.float32)
+ exp_avg_sq = state['exp_avg_sq'].to(torch.float32)
+
+ s = state['s'].to(torch.float32)
+ p0 = state['p0'].to(torch.float32)
+
+ if group_lr > 0.0:
+ # we use d / d0 instead of just d to avoid getting values that are too small
+ d_numerator += (d / d0) * dlr * torch.dot(grad.flatten(),
+ (p0.data - p_fp32.data).flatten()).item()
+
+ # Adam EMA updates
+ exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1))
+ exp_avg_sq.mul_(beta2).addcmul_(
+ grad, grad, value=d * d * (1-beta2))
+
+ if safeguard_warmup:
+ s.mul_(beta3).add_(grad, alpha=((d / d0) * d))
+ else:
+ s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr))
+ d_denom += s.abs().sum().item()
+
+ # update state with stochastic rounding
+ state['exp_avg'] = Auto8bitTensor(exp_avg)
+ state['exp_avg_sq'] = Auto8bitTensor(exp_avg_sq)
+ state['s'] = Auto8bitTensor(s)
+ state['p0'] = Auto8bitTensor(p0)
+
+ d_hat = d
+
+ # if no progress has been made, return early
+ # if any gradients are available, d_denom will be > 0 (unless ||g|| = 0)
+ if d_denom == 0:
+ return loss
+
+ if lr > 0.0:
+ if fsdp_in_use:
+ dist_tensor = torch.zeros(2).cuda()
+ dist_tensor[0] = d_numerator
+ dist_tensor[1] = d_denom
+ dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
+ global_d_numerator = dist_tensor[0]
+ global_d_denom = dist_tensor[1]
+ else:
+ global_d_numerator = d_numerator
+ global_d_denom = d_denom
+
+ d_hat = d_coef * global_d_numerator / global_d_denom
+ if d == group['d0']:
+ d = max(d, d_hat)
+ d_max = max(d_max, d_hat)
+ d = min(d_max, d * growth_rate)
+
+ for group in self.param_groups:
+ group['d_numerator'] = global_d_numerator
+ group['d_denom'] = global_d_denom
+ group['d'] = d
+ group['d_max'] = d_max
+ group['d_hat'] = d_hat
+
+ decay = group['weight_decay']
+ k = group['k']
+ eps = group['eps']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.to(torch.float32)
+ p_fp32 = p.clone().to(torch.float32)
+
+ state = self.state[p]
+
+ exp_avg = state['exp_avg'].to(torch.float32)
+ exp_avg_sq = state['exp_avg_sq'].to(torch.float32)
+
+ state['step'] += 1
+
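+                # Adam-style parameter update with both the step size and eps scaled by the
+                # adapted d; the update is computed in fp32 and written back to the
+                # low-precision weights with stochastic rounding (copy_stochastic below)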
+ denom = exp_avg_sq.sqrt().add_(d * eps)
+
+ # Apply weight decay (decoupled variant)
+ if decay != 0 and decouple:
+ p_fp32.data.add_(p_fp32.data, alpha=-decay * dlr)
+
+ # Take step
+ p_fp32.data.addcdiv_(exp_avg, denom, value=-dlr)
+ # apply stochastic rounding
+ copy_stochastic(p.data, p_fp32.data)
+
+ group['k'] = k + 1
+
+ return loss
diff --git a/toolkit/orig_configs/sd_xl_refiner.yaml b/toolkit/orig_configs/sd_xl_refiner.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cab5fe283d77bf86e0f29e99f3ed0d3c7d9c752f
--- /dev/null
+++ b/toolkit/orig_configs/sd_xl_refiner.yaml
@@ -0,0 +1,91 @@
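+# Upstream sgm (Stability AI generative-models) config for the SDXL refiner,
+# kept here as a reference copy of the original model layout.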
+model:
+ target: sgm.models.diffusion.DiffusionEngine
+ params:
+ scale_factor: 0.13025
+ disable_first_stage_autocast: True
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+ network_config:
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ adm_in_channels: 2560
+ num_classes: sequential
+ use_checkpoint: True
+ in_channels: 4
+ out_channels: 4
+ model_channels: 384
+ attention_resolutions: [4, 2]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 4
+ context_dim: [1280, 1280, 1280, 1280] # 1280
+ spatial_transformer_attn_type: softmax-xformers
+ legacy: False
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ # crossattn and vector cond
+ - is_trainable: False
+ input_key: txt
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
+ params:
+ arch: ViT-bigG-14
+ version: laion2b_s39b_b160k
+ legacy: False
+ freeze: True
+ layer: penultimate
+ always_return_pooled: True
+ # vector cond
+ - is_trainable: False
+ input_key: original_size_as_tuple
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: crop_coords_top_left
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by two
+ # vector cond
+ - is_trainable: False
+ input_key: aesthetic_score
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256 # multiplied by one
+
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: vanilla-xformers
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 4, 4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
diff --git a/toolkit/paths.py b/toolkit/paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..b926c82f13d36790b9b17de7355b3a0d6e1abcbd
--- /dev/null
+++ b/toolkit/paths.py
@@ -0,0 +1,22 @@
+import os
+
+TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
+SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
+REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")
+KEYMAPS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "keymaps")
+ORIG_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "orig_configs")
+DIFFUSERS_CONFIGS_ROOT = os.path.join(TOOLKIT_ROOT, "toolkit", "diffusers_configs")
+
+# check if ENV variable is set
+if 'MODELS_PATH' in os.environ:
+ MODELS_PATH = os.environ['MODELS_PATH']
+else:
+ MODELS_PATH = os.path.join(TOOLKIT_ROOT, "models")
+
+
+def get_path(path):
+ # we allow absolute paths, but if it is not absolute, we assume it is relative to the toolkit root
+ if not os.path.isabs(path):
+ path = os.path.join(TOOLKIT_ROOT, path)
+ return path
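+
+
+# Example (sketch, illustrative paths only):
+#   from toolkit.paths import get_path, MODELS_PATH
+#   cfg_file = get_path("config/my_run.yaml")      # relative paths resolve against TOOLKIT_ROOT
+#   model_dir = os.path.join(MODELS_PATH, "flux")  # MODELS_PATH honours the env variable if set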
diff --git a/toolkit/photomaker.py b/toolkit/photomaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..8037969507854129cf342d8b3fae7a6d1ff7581e
--- /dev/null
+++ b/toolkit/photomaker.py
@@ -0,0 +1,144 @@
+# Merge the image encoder and the fuse module to create an ID Encoder.
+# Given multiple ID images, it directly produces updated prompt embeddings containing a stacked ID embedding.
+
+import torch
+import torch.nn as nn
+from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection
+from transformers.models.clip.configuration_clip import CLIPVisionConfig
+from transformers import PretrainedConfig
+
+VISION_CONFIG_DICT = {
+ "hidden_size": 1024,
+ "intermediate_size": 4096,
+ "num_attention_heads": 16,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768
+}
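+# (defaults matching the CLIP ViT-L/14 vision tower with a 768-dim projection head)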
+
+class MLP(nn.Module):
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
+ super().__init__()
+ if use_residual:
+ assert in_dim == out_dim
+ self.layernorm = nn.LayerNorm(in_dim)
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
+ self.use_residual = use_residual
+ self.act_fn = nn.GELU()
+
+ def forward(self, x):
+ residual = x
+ x = self.layernorm(x)
+ x = self.fc1(x)
+ x = self.act_fn(x)
+ x = self.fc2(x)
+ if self.use_residual:
+ x = x + residual
+ return x
+
+
+class FuseModule(nn.Module):
+ def __init__(self, embed_dim):
+ super().__init__()
+ self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
+ self.layer_norm = nn.LayerNorm(embed_dim)
+
+ def fuse_fn(self, prompt_embeds, id_embeds):
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
+ return stacked_id_embeds
+
+ def forward(
+ self,
+ prompt_embeds,
+ id_embeds,
+ class_tokens_mask,
+ ) -> torch.Tensor:
+ # id_embeds shape: [b, max_num_inputs, 1, 2048]
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
+ num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
+ batch_size, max_num_inputs = id_embeds.shape[:2]
+ # seq_length: 77
+ seq_length = prompt_embeds.shape[1]
+ # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
+ flat_id_embeds = id_embeds.view(
+ -1, id_embeds.shape[-2], id_embeds.shape[-1]
+ )
+ # valid_id_mask [b*max_num_inputs]
+ valid_id_mask = (
+ torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
+ < num_inputs[:, None]
+ )
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
+
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
+ class_tokens_mask = class_tokens_mask.view(-1)
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
+ # slice out the image token embeddings
+ image_token_embeds = prompt_embeds[class_tokens_mask]
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
+ prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
+ return updated_prompt_embeds
+
+class PhotoMakerIDEncoder(CLIPVisionModelWithProjection):
+ def __init__(self, config=None, *model_args, **model_kwargs):
+ if config is None:
+ config = CLIPVisionConfig(**VISION_CONFIG_DICT)
+ super().__init__(config, *model_args, **model_kwargs)
+ self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
+ self.fuse_module = FuseModule(2048)
+
+ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
+ b, num_inputs, c, h, w = id_pixel_values.shape
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
+
+ shared_id_embeds = self.vision_model(id_pixel_values)[1]
+ id_embeds = self.visual_projection(shared_id_embeds)
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
+
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
+
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
+ updated_prompt_embeds = self.fuse_module(
+ prompt_embeds, id_embeds, class_tokens_mask)
+
+ return updated_prompt_embeds
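+
+    # Shape sketch (illustrative, b=1 with a single 224x224 ID image):
+    #   id_pixel_values [1, 1, 3, 224, 224] -> pooled vision output [1, 1024]
+    #   visual_projection -> [1, 768], visual_projection_2 -> [1, 1280], concatenated -> id_embeds [1, 1, 1, 2048]
+    #   prompt_embeds [1, 77, 2048] -> updated_prompt_embeds [1, 77, 2048]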
+
+
+class PhotoMakerCLIPEncoder(CLIPVisionModelWithProjection):
+ def __init__(self, config=None, *model_args, **model_kwargs):
+ if config is None:
+ config = CLIPVisionConfig(**VISION_CONFIG_DICT)
+ super().__init__(config, *model_args, **model_kwargs)
+ self.visual_projection_2 = nn.Linear(1024, 1280, bias=False)
+
+ def forward(self, id_pixel_values, do_projection2=True, output_full=False):
+ b, num_inputs, c, h, w = id_pixel_values.shape
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
+ # last_hidden_state, 1, 257, 1024
+ vision_output = self.vision_model(id_pixel_values, output_hidden_states=True)
+ shared_id_embeds = vision_output[1]
+ id_embeds = self.visual_projection(shared_id_embeds)
+
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
+
+ if do_projection2:
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
+
+ if output_full:
+ return id_embeds, vision_output
+ return id_embeds
+
+
+
+if __name__ == "__main__":
+ PhotoMakerIDEncoder()
\ No newline at end of file
diff --git a/toolkit/photomaker_pipeline.py b/toolkit/photomaker_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6437b648e91e5d4e70abbcf0995d76dc1b00f81
--- /dev/null
+++ b/toolkit/photomaker_pipeline.py
@@ -0,0 +1,491 @@
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from collections import OrderedDict
+import os
+import PIL.Image
+import numpy as np
+
+import torch
+from torchvision import transforms as T
+
+from safetensors import safe_open
+from huggingface_hub.utils import validate_hf_hub_args
+from transformers import CLIPImageProcessor, CLIPTokenizer
+from diffusers import StableDiffusionXLPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from diffusers.utils import (
+ _get_model_file,
+ is_transformers_available,
+ logging,
+)
+
+from .photomaker import PhotoMakerIDEncoder
+
+PipelineImageInput = Union[
+ PIL.Image.Image,
+ torch.FloatTensor,
+ List[PIL.Image.Image],
+ List[torch.FloatTensor],
+]
+
+
+class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
+ @validate_hf_hub_args
+ def load_photomaker_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ weight_name: str,
+ subfolder: str = '',
+ trigger_word: str = 'img',
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ weight_name (`str`):
+ The weight name NOT the path to the weight.
+
+ subfolder (`str`, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+ trigger_word (`str`, *optional*, defaults to `"img"`):
+                The trigger word used to identify the position of the class word in the text prompt;
+                it is recommended not to set it to a common word.
+                When used, the trigger word must be placed after the class word, otherwise it will degrade the quality of the personalized generation.
+ """
+
+ # Load the main state dict first.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"id_encoder": {}, "lora_weights": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("id_encoder."):
+ state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
+ elif key.startswith("lora_weights."):
+ state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if keys != ["id_encoder", "lora_weights"]:
+            raise ValueError("The state dict must contain exactly the keys `id_encoder` and `lora_weights`.")
+
+ self.trigger_word = trigger_word
+ # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet
+ print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...")
+ id_encoder = PhotoMakerIDEncoder()
+ id_encoder.load_state_dict(state_dict["id_encoder"], strict=True)
+ id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype)
+ self.id_encoder = id_encoder
+ self.id_image_processor = CLIPImageProcessor()
+
+ # load lora into models
+ print(f"Loading PhotoMaker components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]")
+ self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
+
+ # Add trigger word token
+ if self.tokenizer is not None:
+ self.tokenizer.add_tokens([self.trigger_word], special_tokens=True)
+
+ self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True)
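+
+    # Typical usage (sketch; the repo id and weight name below are illustrative, not loaded here):
+    #   pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
+    #   pipe.load_photomaker_adapter("TencentARC/PhotoMaker", weight_name="photomaker-v1.bin", trigger_word="img")
+    #   image = pipe(prompt="a photo of a man img", input_id_images=[id_image]).images[0]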
+
+ def encode_prompt_with_trigger_word(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ num_id_images: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ class_tokens_mask: Optional[torch.LongTensor] = None,
+ ):
+ device = device or self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Find the token id of the trigger word
+ image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word)
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ input_ids = tokenizer.encode(prompt) # TODO: batch encode
+ clean_index = 0
+ clean_input_ids = []
+ class_token_index = []
+                # Find the corresponding class word token based on the newly added trigger word token
+ for i, token_id in enumerate(input_ids):
+ if token_id == image_token_id:
+ class_token_index.append(clean_index - 1)
+ else:
+ clean_input_ids.append(token_id)
+ clean_index += 1
+
+ if len(class_token_index) != 1:
+ raise ValueError(
+                        f"PhotoMaker currently does not support multiple trigger words in a single prompt. "
+                        f"Trigger word: {self.trigger_word}, Prompt: {prompt}."
+                    )
+ class_token_index = class_token_index[0]
+
+ # Expand the class word token and corresponding mask
+ class_token = clean_input_ids[class_token_index]
+ clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \
+ clean_input_ids[class_token_index + 1:]
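+                # e.g. for "a photo of a man img" with num_id_images=2, the "img" token is dropped
+                # and the class token "man" is repeated twice; class_tokens_mask below marks exactly
+                # those repeated class-token positions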
+
+ # Truncation or padding
+ max_len = tokenizer.model_max_length
+ if len(clean_input_ids) > max_len:
+ clean_input_ids = clean_input_ids[:max_len]
+ else:
+ clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
+ max_len - len(clean_input_ids)
+ )
+
+                class_tokens_mask = [class_token_index <= i < class_token_index + num_id_images
+                                     for i in range(len(clean_input_ids))]
+
+ clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0)
+ class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0)
+
+ prompt_embeds = text_encoder(
+ clean_input_ids.to(device),
+ output_hidden_states=True,
+ )
+
+                # the pooled output is always taken from the final text encoder only
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case
+
+ return prompt_embeds, pooled_prompt_embeds, class_tokens_mask
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ # Added parameters (for PhotoMaker)
+ input_id_images: PipelineImageInput = None,
+ start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future
+ class_tokens_mask: Optional[torch.LongTensor] = None,
+ prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Only the parameters introduced by PhotoMaker are discussed here.
+ For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+
+ Args:
+ input_id_images (`PipelineImageInput`, *optional*):
+ Input ID Image to work with PhotoMaker.
+ class_tokens_mask (`torch.LongTensor`, *optional*):
+ Pre-generated class token. When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of class word.
+ prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+        if prompt_embeds is not None and class_tokens_mask is None:
+            raise ValueError(
+                "If `prompt_embeds` is provided, `class_tokens_mask` must also be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`."
+            )
+ # check the input id images
+ if input_id_images is None:
+ raise ValueError(
+ "Provide `input_id_images`. Cannot leave `input_id_images` undefined for PhotoMaker pipeline."
+ )
+ if not isinstance(input_id_images, list):
+ input_id_images = [input_id_images]
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ assert do_classifier_free_guidance
+
+ # 3. Encode input prompt
+ num_id_images = len(input_id_images)
+
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ class_tokens_mask,
+ ) = self.encode_prompt_with_trigger_word(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_id_images=num_id_images,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ class_tokens_mask=class_tokens_mask,
+ )
+
+ # 4. Encode input prompt without the trigger word for delayed conditioning
+ prompt_text_only = prompt.replace(" " + self.trigger_word, "") # sensitive to white space
+ (
+ prompt_embeds_text_only,
+ negative_prompt_embeds,
+ pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt_text_only,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds_text_only,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds_text_only,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+
+ # 5. Prepare the input ID images
+ dtype = next(self.id_encoder.parameters()).dtype
+        if not isinstance(input_id_images[0], torch.Tensor):
+            id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values
+        else:
+            # assumes tensor inputs are already preprocessed [C, H, W] images
+            id_pixel_values = torch.stack(input_id_images, dim=0)
+
+        id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype)  # TODO: multiple prompts
+
+ # 6. Get the update text embedding with the stacked ID embedding
+ prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ # 7. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 8. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Prepare added time ids & embeddings
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 11. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
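+                # delayed ID conditioning: for the first `start_merge_step` steps the text-only
+                # embeddings drive the generation, afterwards the ID-fused embeddings take over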
+ if i <= start_merge_step:
+ current_prompt_embeds = torch.cat(
+ [negative_prompt_embeds, prompt_embeds_text_only], dim=0
+ )
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
+ else:
+ current_prompt_embeds = torch.cat(
+ [negative_prompt_embeds, prompt_embeds], dim=0
+ )
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=current_prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ # if self.watermark is not None:
+ # image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
\ No newline at end of file
diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0509ee188f34e19a07546aa0dfd0606ff438426
--- /dev/null
+++ b/toolkit/pipelines.py
@@ -0,0 +1,1421 @@
+import importlib
+import inspect
+from typing import Union, List, Optional, Dict, Any, Tuple, Callable
+
+import numpy as np
+import torch
+from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LMSDiscreteScheduler, FluxPipeline
+from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+# from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_k_diffusion import ModelWrapper
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from diffusers.utils import is_torch_xla_available
+from k_diffusion.external import CompVisVDenoiser, CompVisDenoiser
+from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+class StableDiffusionKDiffusionXLPipeline(StableDiffusionXLPipeline):
+
+ def __init__(
+ self,
+ vae: 'AutoencoderKL',
+ text_encoder: 'CLIPTextModel',
+ text_encoder_2: 'CLIPTextModelWithProjection',
+ tokenizer: 'CLIPTokenizer',
+ tokenizer_2: 'CLIPTokenizer',
+ unet: 'UNet2DConditionModel',
+ scheduler: 'KarrasDiffusionSchedulers',
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ raise NotImplementedError("This pipeline is not implemented yet")
+ # self.sampler = None
+ # scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+ # model = ModelWrapper(unet, scheduler.alphas_cumprod)
+ # if scheduler.config.prediction_type == "v_prediction":
+ # self.k_diffusion_model = CompVisVDenoiser(model)
+ # else:
+ # self.k_diffusion_model = CompVisDenoiser(model)
+
+ def set_scheduler(self, scheduler_type: str):
+ library = importlib.import_module("k_diffusion")
+ sampling = getattr(library, "sampling")
+ self.sampler = getattr(sampling, scheduler_type)
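+        # e.g. set_scheduler("sample_dpmpp_2m") (assuming that sampler exists in your k-diffusion
+        # version); the name must match a sampling function exposed by k_diffusion.sampling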
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ use_karras_sigmas: bool = False,
+ ):
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 5. Prepare sigmas
+ if use_karras_sigmas:
+ sigma_min: float = self.k_diffusion_model.sigmas[0].item()
+ sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
+ sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
+ sigmas = sigmas.to(device)
+ else:
+ sigmas = self.scheduler.sigmas
+ sigmas = sigmas.to(prompt_embeds.dtype)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ latents = latents * sigmas[0]
+ self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
+ self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
+
+ # 7. Define model function
+ def model_fn(x, t):
+ latent_model_input = torch.cat([x] * 2)
+ t = torch.cat([t] * 2)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ # noise_pred = self.unet(
+ # latent_model_input,
+ # t,
+ # encoder_hidden_states=prompt_embeds,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ # added_cond_kwargs=added_cond_kwargs,
+ # return_dict=False,
+ # )[0]
+
+ noise_pred = self.k_diffusion_model(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
+            )[0]
+
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ return noise_pred
+
+
+ # 8. Run k-diffusion solver
+ sampler_kwargs = {}
+        # a fixed noise_sampler_seed is not required; None keeps the Brownian noise sampler non-deterministic
+        noise_sampler_seed = None
+
+
+ if "noise_sampler" in inspect.signature(self.sampler).parameters:
+ min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
+ sampler_kwargs["noise_sampler"] = noise_sampler
+
+ latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+
+class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
+
+ def predict_noise(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ timestep: Optional[int] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # if not predict_noise:
+ # # call parent
+ # return super().__call__(
+ # prompt=prompt,
+ # prompt_2=prompt_2,
+ # height=height,
+ # width=width,
+ # num_inference_steps=num_inference_steps,
+ # denoising_end=denoising_end,
+ # guidance_scale=guidance_scale,
+ # negative_prompt=negative_prompt,
+ # negative_prompt_2=negative_prompt_2,
+ # num_images_per_prompt=num_images_per_prompt,
+ # eta=eta,
+ # generator=generator,
+ # latents=latents,
+ # prompt_embeds=prompt_embeds,
+ # negative_prompt_embeds=negative_prompt_embeds,
+ # pooled_prompt_embeds=pooled_prompt_embeds,
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ # output_type=output_type,
+ # return_dict=return_dict,
+ # callback=callback,
+ # callback_steps=callback_steps,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ # guidance_rescale=guidance_rescale,
+ # original_size=original_size,
+ # crops_coords_top_left=crops_coords_top_left,
+ # target_size=target_size,
+ # )
+
+ # 0. Default height and width to unet
+ height = self.default_sample_size * self.vae_scale_factor
+ width = self.default_sample_size * self.vae_scale_factor
+
+ original_size = (height, width)
+ target_size = (height, width)
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ ).to(device) # TODO DOES NOT CAST ORIGINALLY
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ return noise_pred
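+
+    # Note: unlike __call__, predict_noise returns the CFG-combined noise prediction for a single
+    # timestep on the given latents instead of running the full denoising loop, so callers can
+    # query eps(x_t, t, c) directly.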
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+        print('Called cpu offload', gpu_id)
+        # intentionally a no-op: model CPU offload is disabled for this custom pipeline
+        pass
+
+
+class CustomStableDiffusionPipeline(StableDiffusionPipeline):
+
+    # override __call__ so its signature matches the SDXL call, letting the same code drive both pipelines and stop early
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ ):
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+ # 7.1 Apply denoising_end
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+ # some of the inputs exist only to keep the signature compatible with the SDXL variant.
+ # this runs a single unet forward pass at `timestep` and returns the (optionally CFG-combined) noise prediction.
+ def predict_noise(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ timestep: Optional[int] = None,
+ ):
+
+ # 0. Default height and width to unet
+ height = self.unet.config.sample_size * self.vae_scale_factor
+ width = self.unet.config.sample_size * self.vae_scale_factor
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ return noise_pred
+
+
+class StableDiffusionXLRefinerPipeline(StableDiffusionXLPipeline):
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ denoising_start: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `latents` are already partly denoised. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
+ Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should usually be
+ the same as the `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+
+ Examples:
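+
+ A minimal, purely illustrative call (here `pipe` is assumed to be an already-constructed
+ instance of this class; how the base/refiner hand-off is wired up elsewhere in the
+ toolkit is outside the scope of this snippet):
+
+ >>> image = pipe(
+ ... prompt="a photo of an astronaut riding a horse",
+ ... num_inference_steps=40,
+ ... denoising_start=0.8,
+ ... ).images[0]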
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 8.2 Determine denoising_start
+ denoising_start_index = 0
+ if denoising_start is not None and isinstance(denoising_start, float) and denoising_start > 0 and denoising_start < 1:
+ discrete_timestep_start = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ denoising_start_index = len(list(filter(lambda ts: ts < discrete_timestep_start, timesteps)))
+
+
+ with self.progress_bar(total=num_inference_steps - denoising_start_index) as progress_bar:
+ for i, t in enumerate(timesteps, start=denoising_start_index):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+
+
+
+# TODO this is rough. Need to properly stack unconditional
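+# In practice this class adds true classifier-free guidance to FLUX: the learned
+# guidance embedding is bypassed, the transformer is run once on the positive and
+# once on the negative prompt, and the two predictions are combined with the usual
+# CFG formula (uncond + scale * (text - uncond)).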
+class FluxWithCFGPipeline(FluxPipeline):
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ # bypass the guidance embedding if there is one
+ bypass_flux_guidance(self.transformer)
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.tensor([guidance_scale], device=device)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ noise_pred_text = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # TODO: combine the conditional and unconditional predictions into a single batched forward pass
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+ restore_flux_guidance(self.transformer)
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
\ No newline at end of file
diff --git a/toolkit/progress_bar.py b/toolkit/progress_bar.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42f8086a7d29016beea66b09e8c0fdc574c5422
--- /dev/null
+++ b/toolkit/progress_bar.py
@@ -0,0 +1,25 @@
+from tqdm import tqdm
+import time
+
+
+class ToolkitProgressBar(tqdm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.paused = False
+ self.last_time = self._time()
+
+ def pause(self):
+ if not self.paused:
+ self.paused = True
+ self.last_time = self._time()
+
+ def unpause(self):
+ if self.paused:
+ self.paused = False
+ cur_t = self._time()
+ self.start_t += cur_t - self.last_time
+ self.last_print_t = cur_t
+
+ def update(self, *args, **kwargs):
+ if not self.paused:
+ super().update(*args, **kwargs)
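+
+
+# Illustrative usage sketch: pause the bar while doing work that should not count
+# toward the iteration timing (e.g. sampling or saving a checkpoint), then resume.
+#
+#   pbar = ToolkitProgressBar(total=100)
+#   pbar.pause()
+#   # ... untimed work here ...
+#   pbar.unpause()
+#   pbar.update(1)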
diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a145841cc75e9fd96d9278d4eb8d78db262c7c8d
--- /dev/null
+++ b/toolkit/prompt_utils.py
@@ -0,0 +1,561 @@
+import os
+from typing import Optional, TYPE_CHECKING, List, Union, Tuple
+
+import torch
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+import random
+
+from toolkit.train_tools import get_torch_dtype
+import itertools
+
+if TYPE_CHECKING:
+ from toolkit.config_modules import SliderTargetConfig
+
+
+class ACTION_TYPES_SLIDER:
+ ERASE_NEGATIVE = 0
+ ENHANCE_NEGATIVE = 1
+
+
+class PromptEmbeds:
+ # text_embeds: torch.Tensor
+ # pooled_embeds: Union[torch.Tensor, None]
+ # attention_mask: Union[torch.Tensor, None]
+
+ def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
+ if isinstance(args, list) or isinstance(args, tuple):
+ # xl
+ self.text_embeds = args[0]
+ self.pooled_embeds = args[1]
+ else:
+ # sdv1.x, sdv2.x
+ self.text_embeds = args
+ self.pooled_embeds = None
+
+ self.attention_mask = attention_mask
+
+ def to(self, *args, **kwargs):
+ self.text_embeds = self.text_embeds.to(*args, **kwargs)
+ if self.pooled_embeds is not None:
+ self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
+ if self.attention_mask is not None:
+ self.attention_mask = self.attention_mask.to(*args, **kwargs)
+ return self
+
+ def detach(self):
+ new_embeds = self.clone()
+ new_embeds.text_embeds = new_embeds.text_embeds.detach()
+ if new_embeds.pooled_embeds is not None:
+ new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
+ if new_embeds.attention_mask is not None:
+ new_embeds.attention_mask = new_embeds.attention_mask.detach()
+ return new_embeds
+
+ def clone(self):
+ if self.pooled_embeds is not None:
+ prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
+ else:
+ prompt_embeds = PromptEmbeds(self.text_embeds.clone())
+
+ if self.attention_mask is not None:
+ prompt_embeds.attention_mask = self.attention_mask.clone()
+ return prompt_embeds
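+
+# PromptEmbeds is the common container passed around this toolkit: `text_embeds` always
+# holds the per-token embeddings, `pooled_embeds` is only set for SDXL-style models, and
+# `attention_mask` is optional.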
+
+
+class EncodedPromptPair:
+ def __init__(
+ self,
+ target_class,
+ target_class_with_neutral,
+ positive_target,
+ positive_target_with_neutral,
+ negative_target,
+ negative_target_with_neutral,
+ neutral,
+ empty_prompt,
+ both_targets,
+ action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+ action_list=None,
+ multiplier=1.0,
+ multiplier_list=None,
+ weight=1.0,
+ target: 'SliderTargetConfig' = None,
+ ):
+ self.target_class: PromptEmbeds = target_class
+ self.target_class_with_neutral: PromptEmbeds = target_class_with_neutral
+ self.positive_target: PromptEmbeds = positive_target
+ self.positive_target_with_neutral: PromptEmbeds = positive_target_with_neutral
+ self.negative_target: PromptEmbeds = negative_target
+ self.negative_target_with_neutral: PromptEmbeds = negative_target_with_neutral
+ self.neutral: PromptEmbeds = neutral
+ self.empty_prompt: PromptEmbeds = empty_prompt
+ self.both_targets: PromptEmbeds = both_targets
+ self.multiplier: float = multiplier
+ self.target: 'SliderTargetConfig' = target
+ if multiplier_list is not None:
+ self.multiplier_list: list[float] = multiplier_list
+ else:
+ self.multiplier_list: list[float] = [multiplier]
+ self.action: int = action
+ if action_list is not None:
+ self.action_list: list[int] = action_list
+ else:
+ self.action_list: list[int] = [action]
+ self.weight: float = weight
+
+ # simulate torch to for tensors
+ def to(self, *args, **kwargs):
+ self.target_class = self.target_class.to(*args, **kwargs)
+ self.target_class_with_neutral = self.target_class_with_neutral.to(*args, **kwargs)
+ self.positive_target = self.positive_target.to(*args, **kwargs)
+ self.positive_target_with_neutral = self.positive_target_with_neutral.to(*args, **kwargs)
+ self.negative_target = self.negative_target.to(*args, **kwargs)
+ self.negative_target_with_neutral = self.negative_target_with_neutral.to(*args, **kwargs)
+ self.neutral = self.neutral.to(*args, **kwargs)
+ self.empty_prompt = self.empty_prompt.to(*args, **kwargs)
+ self.both_targets = self.both_targets.to(*args, **kwargs)
+ return self
+
+ def detach(self):
+ self.target_class = self.target_class.detach()
+ self.target_class_with_neutral = self.target_class_with_neutral.detach()
+ self.positive_target = self.positive_target.detach()
+ self.positive_target_with_neutral = self.positive_target_with_neutral.detach()
+ self.negative_target = self.negative_target.detach()
+ self.negative_target_with_neutral = self.negative_target_with_neutral.detach()
+ self.neutral = self.neutral.detach()
+ self.empty_prompt = self.empty_prompt.detach()
+ self.both_targets = self.both_targets.detach()
+ return self
+
+
+def concat_prompt_embeds(prompt_embeds: list[PromptEmbeds]):
+ text_embeds = torch.cat([p.text_embeds for p in prompt_embeds], dim=0)
+ pooled_embeds = None
+ if prompt_embeds[0].pooled_embeds is not None:
+ pooled_embeds = torch.cat([p.pooled_embeds for p in prompt_embeds], dim=0)
+ return PromptEmbeds([text_embeds, pooled_embeds])
+
+
+def concat_prompt_pairs(prompt_pairs: list[EncodedPromptPair]):
+ weight = prompt_pairs[0].weight
+ target_class = concat_prompt_embeds([p.target_class for p in prompt_pairs])
+ target_class_with_neutral = concat_prompt_embeds([p.target_class_with_neutral for p in prompt_pairs])
+ positive_target = concat_prompt_embeds([p.positive_target for p in prompt_pairs])
+ positive_target_with_neutral = concat_prompt_embeds([p.positive_target_with_neutral for p in prompt_pairs])
+ negative_target = concat_prompt_embeds([p.negative_target for p in prompt_pairs])
+ negative_target_with_neutral = concat_prompt_embeds([p.negative_target_with_neutral for p in prompt_pairs])
+ neutral = concat_prompt_embeds([p.neutral for p in prompt_pairs])
+ empty_prompt = concat_prompt_embeds([p.empty_prompt for p in prompt_pairs])
+ both_targets = concat_prompt_embeds([p.both_targets for p in prompt_pairs])
+ # combine all the lists
+ action_list = []
+ multiplier_list = []
+ weight_list = []
+ for p in prompt_pairs:
+ action_list += p.action_list
+ multiplier_list += p.multiplier_list
+ return EncodedPromptPair(
+ target_class=target_class,
+ target_class_with_neutral=target_class_with_neutral,
+ positive_target=positive_target,
+ positive_target_with_neutral=positive_target_with_neutral,
+ negative_target=negative_target,
+ negative_target_with_neutral=negative_target_with_neutral,
+ neutral=neutral,
+ empty_prompt=empty_prompt,
+ both_targets=both_targets,
+ action_list=action_list,
+ multiplier_list=multiplier_list,
+ weight=weight,
+ target=prompt_pairs[0].target
+ )
+
+
+def split_prompt_embeds(concatenated: PromptEmbeds, num_parts=None) -> List[PromptEmbeds]:
+ if num_parts is None:
+ # use batch size
+ num_parts = concatenated.text_embeds.shape[0]
+ text_embeds_splits = torch.chunk(concatenated.text_embeds, num_parts, dim=0)
+
+ if concatenated.pooled_embeds is not None:
+ pooled_embeds_splits = torch.chunk(concatenated.pooled_embeds, num_parts, dim=0)
+ else:
+ pooled_embeds_splits = [None] * num_parts
+
+ prompt_embeds_list = [
+ PromptEmbeds([text, pooled])
+ for text, pooled in zip(text_embeds_splits, pooled_embeds_splits)
+ ]
+
+ return prompt_embeds_list
+
+
+def split_prompt_pairs(concatenated: EncodedPromptPair, num_embeds=None) -> List[EncodedPromptPair]:
+ target_class_splits = split_prompt_embeds(concatenated.target_class, num_embeds)
+ target_class_with_neutral_splits = split_prompt_embeds(concatenated.target_class_with_neutral, num_embeds)
+ positive_target_splits = split_prompt_embeds(concatenated.positive_target, num_embeds)
+ positive_target_with_neutral_splits = split_prompt_embeds(concatenated.positive_target_with_neutral, num_embeds)
+ negative_target_splits = split_prompt_embeds(concatenated.negative_target, num_embeds)
+ negative_target_with_neutral_splits = split_prompt_embeds(concatenated.negative_target_with_neutral, num_embeds)
+ neutral_splits = split_prompt_embeds(concatenated.neutral, num_embeds)
+ empty_prompt_splits = split_prompt_embeds(concatenated.empty_prompt, num_embeds)
+ both_targets_splits = split_prompt_embeds(concatenated.both_targets, num_embeds)
+
+ prompt_pairs = []
+ for i in range(len(target_class_splits)):
+ action_list_split = concatenated.action_list[i::len(target_class_splits)]
+ multiplier_list_split = concatenated.multiplier_list[i::len(target_class_splits)]
+
+ prompt_pair = EncodedPromptPair(
+ target_class=target_class_splits[i],
+ target_class_with_neutral=target_class_with_neutral_splits[i],
+ positive_target=positive_target_splits[i],
+ positive_target_with_neutral=positive_target_with_neutral_splits[i],
+ negative_target=negative_target_splits[i],
+ negative_target_with_neutral=negative_target_with_neutral_splits[i],
+ neutral=neutral_splits[i],
+ empty_prompt=empty_prompt_splits[i],
+ both_targets=both_targets_splits[i],
+ action_list=action_list_split,
+ multiplier_list=multiplier_list_split,
+ weight=concatenated.weight,
+ target=concatenated.target
+ )
+ prompt_pairs.append(prompt_pair)
+
+ return prompt_pairs
+
+
+class PromptEmbedsCache:
+ prompts: dict[str, PromptEmbeds] = {}
+
+ def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
+ self.prompts[__name] = __value
+
+ def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
+ if __name in self.prompts:
+ return self.prompts[__name]
+ else:
+ return None
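+
+# Note: `prompts` is a class-level attribute, so every PromptEmbedsCache instance shares
+# the same underlying dict. Illustrative usage (assumes `sd.encode_prompt` returns a
+# PromptEmbeds, as elsewhere in this module):
+#
+#   cache = PromptEmbedsCache()
+#   if cache["a photo of a dog"] is None:
+#       cache["a photo of a dog"] = sd.encode_prompt("a photo of a dog")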
+
+
+class EncodedAnchor:
+ def __init__(
+ self,
+ prompt,
+ neg_prompt,
+ multiplier=1.0,
+ multiplier_list=None
+ ):
+ self.prompt = prompt
+ self.neg_prompt = neg_prompt
+ self.multiplier = multiplier
+
+ if multiplier_list is not None:
+ self.multiplier_list: list[float] = multiplier_list
+ else:
+ self.multiplier_list: list[float] = [multiplier]
+
+ def to(self, *args, **kwargs):
+ self.prompt = self.prompt.to(*args, **kwargs)
+ self.neg_prompt = self.neg_prompt.to(*args, **kwargs)
+ return self
+
+
+def concat_anchors(anchors: list[EncodedAnchor]):
+ prompt = concat_prompt_embeds([a.prompt for a in anchors])
+ neg_prompt = concat_prompt_embeds([a.neg_prompt for a in anchors])
+ return EncodedAnchor(
+ prompt=prompt,
+ neg_prompt=neg_prompt,
+ multiplier_list=[a.multiplier for a in anchors]
+ )
+
+
+def split_anchors(concatenated: EncodedAnchor, num_anchors: int = 4) -> List[EncodedAnchor]:
+ prompt_splits = split_prompt_embeds(concatenated.prompt, num_anchors)
+ neg_prompt_splits = split_prompt_embeds(concatenated.neg_prompt, num_anchors)
+ multiplier_list_splits = torch.chunk(torch.tensor(concatenated.multiplier_list), num_anchors)
+
+ anchors = []
+ for prompt, neg_prompt, multiplier in zip(prompt_splits, neg_prompt_splits, multiplier_list_splits):
+ anchor = EncodedAnchor(
+ prompt=prompt,
+ neg_prompt=neg_prompt,
+ multiplier=multiplier.tolist()
+ )
+ anchors.append(anchor)
+
+ return anchors
+
+
+def get_permutations(s, max_permutations=8):
+ # Split the string by comma
+ phrases = [phrase.strip() for phrase in s.split(',')]
+
+ # remove empty strings
+ phrases = [phrase for phrase in phrases if len(phrase) > 0]
+ # shuffle the list
+ random.shuffle(phrases)
+
+ # Get all permutations
+ permutations = list(itertools.islice(itertools.permutations(phrases), max_permutations))
+
+ # Convert the tuples back to comma separated strings
+ return [', '.join(permutation) for permutation in permutations]
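+
+# Illustrative example: get_permutations("red hair, blue eyes", max_permutations=2)
+# returns the two orderings ["red hair, blue eyes", "blue eyes, red hair"]
+# (in a random order, since the phrase list is shuffled first).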
+
+
+def get_slider_target_permutations(target: 'SliderTargetConfig', max_permutations=8) -> List['SliderTargetConfig']:
+ from toolkit.config_modules import SliderTargetConfig
+ pos_permutations = get_permutations(target.positive, max_permutations=max_permutations)
+ neg_permutations = get_permutations(target.negative, max_permutations=max_permutations)
+
+ permutations = []
+ for pos, neg in itertools.product(pos_permutations, neg_permutations):
+ permutations.append(
+ SliderTargetConfig(
+ target_class=target.target_class,
+ positive=pos,
+ negative=neg,
+ multiplier=target.multiplier,
+ weight=target.weight
+ )
+ )
+
+ # shuffle the list
+ random.shuffle(permutations)
+
+ if len(permutations) > max_permutations:
+ permutations = permutations[:max_permutations]
+
+ return permutations
+
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+@torch.no_grad()
+def encode_prompts_to_cache(
+ prompt_list: list[str],
+ sd: "StableDiffusion",
+ cache: Optional[PromptEmbedsCache] = None,
+ prompt_tensor_file: Optional[str] = None,
+) -> PromptEmbedsCache:
+ # TODO: add support for larger prompts
+ if cache is None:
+ cache = PromptEmbedsCache()
+
+ if prompt_tensor_file is not None:
+ # check to see if it exists
+ if os.path.exists(prompt_tensor_file):
+ # load it.
+ print(f"Loading prompt tensors from {prompt_tensor_file}")
+ prompt_tensors = load_file(prompt_tensor_file, device='cpu')
+ # add them to the cache
+ for prompt_txt, prompt_tensor in tqdm(prompt_tensors.items(), desc="Loading prompts", leave=False):
+ if prompt_txt.startswith("te:"):
+ prompt = prompt_txt[3:]
+ # text_embeds
+ text_embeds = prompt_tensor
+ pooled_embeds = None
+ # find the matching pooled embeds, if present
+ if f"pe:{prompt}" in prompt_tensors:
+ pooled_embeds = prompt_tensors[f"pe:{prompt}"]
+
+ # rebuild the PromptEmbeds container
+ prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
+ cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)
+
+ if len(cache.prompts) == 0:
+ print("Prompt tensors not found. Encoding prompts..")
+ empty_prompt = ""
+ # encode empty_prompt
+ cache[empty_prompt] = sd.encode_prompt(empty_prompt)
+
+ for p in tqdm(prompt_list, desc="Encoding prompts", leave=False):
+ # build the cache
+ if cache[p] is None:
+ cache[p] = sd.encode_prompt(p).to(device="cpu", dtype=torch.float16)
+
+ # should we shard? It can get large
+ if prompt_tensor_file:
+ print(f"Saving prompt tensors to {prompt_tensor_file}")
+ state_dict = {}
+ for prompt_txt, prompt_embeds in cache.prompts.items():
+ state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to(
+ "cpu", dtype=get_torch_dtype('fp16')
+ )
+ if prompt_embeds.pooled_embeds is not None:
+ state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to(
+ "cpu",
+ dtype=get_torch_dtype('fp16')
+ )
+ save_file(state_dict, prompt_tensor_file)
+
+ return cache
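+
+# Illustrative usage (the cache file path is hypothetical; `sd` is an already-loaded
+# StableDiffusion wrapper from this toolkit):
+#
+#   cache = encode_prompts_to_cache(
+#       ["a photo of a dog", "a photo of a cat"],
+#       sd,
+#       prompt_tensor_file="/tmp/prompt_cache.safetensors",
+#   )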
+
+
+@torch.no_grad()
+def build_prompt_pair_batch_from_cache(
+ cache: PromptEmbedsCache,
+ target: 'SliderTargetConfig',
+ neutral: Optional[str] = '',
+) -> list[EncodedPromptPair]:
+ erase_negative = len(target.positive.strip()) == 0
+ enhance_positive = len(target.negative.strip()) == 0
+
+ both = not erase_negative and not enhance_positive
+
+ prompt_pair_batch = []
+
+ if both or erase_negative:
+ # print("Encoding erase negative")
+ prompt_pair_batch += [
+ # erase standard
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
+ positive_target=cache[f"{target.positive}"],
+ positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
+ negative_target=cache[f"{target.negative}"],
+ negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
+ neutral=cache[neutral],
+ action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+ multiplier=target.multiplier,
+ both_targets=cache[f"{target.positive} {target.negative}"],
+ empty_prompt=cache[""],
+ weight=target.weight,
+ target=target
+ ),
+ ]
+ if both or enhance_positive:
+ # print("Encoding enhance positive")
+ prompt_pair_batch += [
+ # enhance standard, swap pos neg
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
+ positive_target=cache[f"{target.negative}"],
+ positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
+ negative_target=cache[f"{target.positive}"],
+ negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
+ neutral=cache[neutral],
+ action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+ multiplier=target.multiplier,
+ both_targets=cache[f"{target.positive} {target.negative}"],
+ empty_prompt=cache[""],
+ weight=target.weight,
+ target=target
+ ),
+ ]
+ if both or enhance_positive:
+ # print("Encoding erase positive (inverse)")
+ prompt_pair_batch += [
+ # erase inverted
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
+ positive_target=cache[f"{target.negative}"],
+ positive_target_with_neutral=cache[f"{target.negative} {neutral}"],
+ negative_target=cache[f"{target.positive}"],
+ negative_target_with_neutral=cache[f"{target.positive} {neutral}"],
+ neutral=cache[neutral],
+ action=ACTION_TYPES_SLIDER.ERASE_NEGATIVE,
+ both_targets=cache[f"{target.positive} {target.negative}"],
+ empty_prompt=cache[""],
+ multiplier=target.multiplier * -1.0,
+ weight=target.weight,
+ target=target
+ ),
+ ]
+ if both or erase_negative:
+ # print("Encoding enhance negative (inverse)")
+ prompt_pair_batch += [
+ # enhance inverted
+ EncodedPromptPair(
+ target_class=cache[target.target_class],
+ target_class_with_neutral=cache[f"{target.target_class} {neutral}"],
+ positive_target=cache[f"{target.positive}"],
+ positive_target_with_neutral=cache[f"{target.positive} {neutral}"],
+ negative_target=cache[f"{target.negative}"],
+ negative_target_with_neutral=cache[f"{target.negative} {neutral}"],
+ both_targets=cache[f"{target.positive} {target.negative}"],
+ neutral=cache[neutral],
+ action=ACTION_TYPES_SLIDER.ENHANCE_NEGATIVE,
+ empty_prompt=cache[""],
+ multiplier=target.multiplier * -1.0,
+ weight=target.weight,
+ target=target
+ ),
+ ]
+
+ return prompt_pair_batch
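+
+# Depending on which of `target.positive` / `target.negative` are non-empty, this returns
+# two or four EncodedPromptPair entries: the standard erase/enhance pairs plus their
+# sign-inverted counterparts (multiplier * -1.0).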
+
+
+def build_latent_image_batch_for_prompt_pair(
+ pos_latent,
+ neg_latent,
+ prompt_pair: EncodedPromptPair,
+ prompt_chunk_size
+):
+ erase_negative = len(prompt_pair.target.positive.strip()) == 0
+ enhance_positive = len(prompt_pair.target.negative.strip()) == 0
+ both = not erase_negative and not enhance_positive
+
+ prompt_pair_chunks = split_prompt_pairs(prompt_pair, prompt_chunk_size)
+ if both and len(prompt_pair_chunks) != 4:
+ raise Exception("Invalid prompt pair chunks")
+ if (erase_negative or enhance_positive) and len(prompt_pair_chunks) != 2:
+ raise Exception("Invalid prompt pair chunks")
+
+ latent_list = []
+
+ if both or erase_negative:
+ latent_list.append(pos_latent)
+ if both or enhance_positive:
+ latent_list.append(pos_latent)
+ if both or enhance_positive:
+ latent_list.append(neg_latent)
+ if both or erase_negative:
+ latent_list.append(neg_latent)
+
+ return torch.cat(latent_list, dim=0)
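+
+# Note: the latents are stacked in the same order as the prompt pairs built by
+# build_prompt_pair_batch_from_cache above (erase standard, enhance standard,
+# erase inverted, enhance inverted), so chunk i of the prompt pair batch lines up
+# with latent chunk i.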
+
+
+def inject_trigger_into_prompt(prompt, trigger=None, to_replace_list=None, add_if_not_present=True):
+ if trigger is None:
+ # process as empty string to remove any [trigger] tokens
+ trigger = ''
+ output_prompt = prompt
+ default_replacements = ["[name]", "[trigger]"]
+
+ replace_with = trigger
+ if to_replace_list is None:
+ to_replace_list = default_replacements
+ else:
+ to_replace_list += default_replacements
+
+ # remove duplicates
+ to_replace_list = list(set(to_replace_list))
+
+ # replace them all
+ for to_replace in to_replace_list:
+ # replace it
+ output_prompt = output_prompt.replace(to_replace, replace_with)
+
+ if trigger.strip() != "":
+ # see how many times replace_with is in the prompt
+ num_instances = output_prompt.count(replace_with)
+
+ if num_instances == 0 and add_if_not_present:
+ # add it to the beginning of the prompt
+ output_prompt = replace_with + " " + output_prompt
+
+ # if num_instances > 1:
+ # print(
+ # f"Warning: {trigger} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
+
+ return output_prompt
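+
+
+# Illustrative examples:
+#   inject_trigger_into_prompt("photo of [trigger] smiling", "sks person")
+#   -> "photo of sks person smiling"
+#   inject_trigger_into_prompt("photo of a cat", "sks person")  # no token present
+#   -> "sks person photo of a cat" (prepended because add_if_not_present=True)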
diff --git a/toolkit/reference_adapter.py b/toolkit/reference_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00dfb72974d03917723a4ef54caee7f32dbcdd1
--- /dev/null
+++ b/toolkit/reference_adapter.py
@@ -0,0 +1,410 @@
+import math
+
+import torch
+import sys
+
+from PIL import Image
+from torch.nn import Parameter
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from toolkit.basic import adain
+from toolkit.paths import REPOS_ROOT
+from toolkit.saving import load_ip_adapter_model
+from toolkit.train_tools import get_torch_dtype
+
+sys.path.append(REPOS_ROOT)
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
+from collections import OrderedDict
+from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
+ AttnProcessor2_0
+from ipadapter.ip_adapter.ip_adapter import ImageProjModel
+from ipadapter.ip_adapter.resampler import Resampler
+from toolkit.config_modules import AdapterConfig
+from toolkit.prompt_utils import PromptEmbeds
+import weakref
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+from diffusers import (
+ EulerDiscreteScheduler,
+ DDPMScheduler,
+)
+
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection
+)
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
+
+from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
+
+from transformers import ViTFeatureExtractor, ViTForImageClassification
+
+import torch.nn.functional as F
+import torch.nn as nn
+
+
+class ReferenceAttnProcessor2_0(torch.nn.Module):
+ r"""
+ Attention processor for the reference adapter (PyTorch 2.0). The interface mirrors the IP-Adapter attention processor.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ the weight scale of image prompt.
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.ref_net = nn.Linear(hidden_size, hidden_size)
+ self.blend = nn.Parameter(torch.zeros(hidden_size))
+ self.adapter_ref: weakref.ref = weakref.ref(adapter)
+ self._memory = None
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
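+        # reference mechanism: in "write" mode this processor stores a learned projection of its own
+        # hidden states as memory; in "read" mode that memory is blended back into the current hidden
+        # states using a learned per-channel blend weight (see ReferenceAdapter.forward / unet_forward)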
+ if self.adapter_ref().is_active:
+ if self.adapter_ref().reference_mode == "write":
+ # write_mode
+ memory_ref = self.ref_net(hidden_states)
+ self._memory = memory_ref
+ elif self.adapter_ref().reference_mode == "read":
+ # read_mode
+ if self._memory is None:
+ print("Warning: no memory to read from")
+ else:
+
+ saved_hidden_states = self._memory
+ try:
+ new_hidden_states = saved_hidden_states
+ blend = self.blend
+                        # expand the blend but keep dim 0 (batch) the same
+ while blend.ndim < new_hidden_states.ndim:
+ blend = blend.unsqueeze(0)
+ # expand batch
+ blend = torch.cat([blend] * new_hidden_states.shape[0], dim=0)
+ hidden_states = blend * new_hidden_states + (1 - blend) * hidden_states
+                    except Exception as e:
+                        raise RuntimeError(f"Error blending reference memory: {e}") from e
+
+ return hidden_states
+
+
+class ReferenceAdapter(torch.nn.Module):
+
+ def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
+ super().__init__()
+ self.config = adapter_config
+ self.sd_ref: weakref.ref = weakref.ref(sd)
+ self.device = self.sd_ref().unet.device
+ self.reference_mode = "read"
+ self.current_scale = 1.0
+ self.is_active = True
+ self._reference_images = None
+ self._reference_latents = None
+ self.has_memory = False
+
+ self.noise_scheduler: Union[DDPMScheduler, EulerDiscreteScheduler] = None
+
+ # init adapter modules
+ attn_procs = {}
+ unet_sd = sd.unet.state_dict()
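+        # swap in attention processors: self-attention layers (attn1) keep the standard AttnProcessor2_0,
+        # while cross-attention layers get a ReferenceAttnProcessor2_0 that can store and re-inject
+        # reference features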
+ for name in sd.unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
+ if name.startswith("mid_block"):
+ hidden_size = sd.unet.config['block_out_channels'][-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = sd.unet.config['block_out_channels'][block_id]
+ else:
+                # the upstream code didn't have this branch, but hidden_size would be undefined below without it
+ raise ValueError(f"unknown attn processor name: {name}")
+ if cross_attention_dim is None:
+ attn_procs[name] = AttnProcessor2_0()
+ else:
+ # layer_name = name.split(".processor")[0]
+ # weights = {
+ # "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
+ # "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+ # }
+
+ attn_procs[name] = ReferenceAttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=self.config.num_tokens,
+ adapter=self
+ )
+ # attn_procs[name].load_state_dict(weights)
+ sd.unet.set_attn_processor(attn_procs)
+ adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+
+ sd.adapter = self
+ self.unet_ref: weakref.ref = weakref.ref(sd.unet)
+ self.adapter_modules = adapter_modules
+ # load the weights if we have some
+ if self.config.name_or_path:
+ loaded_state_dict = load_ip_adapter_model(
+ self.config.name_or_path,
+ device='cpu',
+ dtype=sd.torch_dtype
+ )
+ self.load_state_dict(loaded_state_dict)
+
+ self.set_scale(1.0)
+ self.attach()
+ self.to(self.device, self.sd_ref().torch_dtype)
+
+ # if self.config.train_image_encoder:
+ # self.image_encoder.train()
+ # self.image_encoder.requires_grad_(True)
+
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ # self.image_encoder.to(*args, **kwargs)
+ # self.image_proj_model.to(*args, **kwargs)
+ self.adapter_modules.to(*args, **kwargs)
+ return self
+
+ def load_reference_adapter(self, state_dict: Union[OrderedDict, dict]):
+        # this adapter has no `pipe` attribute; use the unet from the referenced StableDiffusion instance
+        reference_layers = torch.nn.ModuleList(self.sd_ref().unet.attn_processors.values())
+ reference_layers.load_state_dict(state_dict["reference_adapter"])
+
+ # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
+ # self.load_ip_adapter(state_dict)
+
+ def state_dict(self) -> OrderedDict:
+ state_dict = OrderedDict()
+ state_dict["reference_adapter"] = self.adapter_modules.state_dict()
+ return state_dict
+
+ def get_scale(self):
+ return self.current_scale
+
+    def set_reference_images(self, reference_images: Optional[torch.Tensor]):
+        # allow passing None to clear the current reference images
+        self._reference_images = reference_images.clone().detach() if reference_images is not None else None
+ self._reference_latents = None
+ self.clear_memory()
+
+ def set_blank_reference_images(self, batch_size):
+ self._reference_images = torch.zeros((batch_size, 3, 512, 512), device=self.device, dtype=self.sd_ref().torch_dtype)
+ self._reference_latents = torch.zeros((batch_size, 4, 64, 64), device=self.device, dtype=self.sd_ref().torch_dtype)
+ self.clear_memory()
+
+
+ def set_scale(self, scale):
+ self.current_scale = scale
+ for attn_processor in self.sd_ref().unet.attn_processors.values():
+ if isinstance(attn_processor, ReferenceAttnProcessor2_0):
+ attn_processor.scale = scale
+
+
+ def attach(self):
+ unet = self.sd_ref().unet
+ self._original_unet_forward = unet.forward
+ unet.forward = lambda *args, **kwargs: self.unet_forward(*args, **kwargs)
+ if self.sd_ref().network is not None:
+ # set network to not merge in
+ self.sd_ref().network.can_merge_in = False
+
+ def unet_forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
+ skip = False
+ if self._reference_images is None and self._reference_latents is None:
+ skip = True
+ if not self.is_active:
+ skip = True
+
+ if self.has_memory:
+ skip = True
+
+ if not skip:
+ if self.sd_ref().network is not None:
+ self.sd_ref().network.is_active = True
+ if self.sd_ref().network.is_merged_in:
+ raise ValueError("network is merged in, but we are not supposed to be merged in")
+ # send it through our forward first
+ self.forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
+
+ if self.sd_ref().network is not None:
+ self.sd_ref().network.is_active = False
+
+ # Send it through the original unet forward
+        return self._original_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
+
+
+ # use drop for prompt dropout, or negatives
+ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
+ if not self.noise_scheduler:
+ raise ValueError("noise scheduler not set")
+ if not self.is_active or (self._reference_images is None and self._reference_latents is None):
+ raise ValueError("reference adapter not active or no reference images set")
+ # todo may need to handle cfg?
+ self.reference_mode = "write"
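+        # "write" pass: the reference latents are pushed through the original unet so every
+        # ReferenceAttnProcessor2_0 caches its memory; the caller's real forward then runs in
+        # "read" mode and blends that memory back in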
+
+ if self._reference_latents is None:
+ self._reference_latents = self.sd_ref().encode_images(self._reference_images.to(
+ self.device, self.sd_ref().torch_dtype
+ )).detach()
+ # create a sample from our reference images
+ reference_latents = self._reference_latents.clone().detach().to(self.device, self.sd_ref().torch_dtype)
+ # if our num of samples are half of incoming, we are doing cfg. Zero out the first half (unconditional)
+ if reference_latents.shape[0] * 2 == sample.shape[0]:
+ # we are doing cfg
+ # Unconditional goes first
+ reference_latents = torch.cat([torch.zeros_like(reference_latents), reference_latents], dim=0).detach()
+
+        # resize so the reference latents fit inside the sample, centered
+        height_scale = sample.shape[2] / reference_latents.shape[2]
+        width_scale = sample.shape[3] / reference_latents.shape[3]
+        scale = min(height_scale, width_scale)
+ # resize the reference latents
+
+ mode = "bilinear" if scale > 1.0 else "bicubic"
+
+ reference_latents = F.interpolate(
+ reference_latents,
+ size=(int(reference_latents.shape[2] * scale), int(reference_latents.shape[3] * scale)),
+ mode=mode,
+ align_corners=False
+ )
+
+        # add 0 padding if needed (F.pad pads the last dim (width) first, then height)
+        height_pad = (sample.shape[2] - reference_latents.shape[2]) / 2
+        width_pad = (sample.shape[3] - reference_latents.shape[3]) / 2
+        reference_latents = F.pad(
+            reference_latents,
+            (math.floor(width_pad), math.ceil(width_pad), math.floor(height_pad), math.ceil(height_pad)),
+            mode="constant",
+            value=0
+        )
+
+ # resize again just to make sure it is exact same size
+ reference_latents = F.interpolate(
+ reference_latents,
+ size=(sample.shape[2], sample.shape[3]),
+ mode="bicubic",
+ align_corners=False
+ )
+
+ # todo maybe add same noise to the sample? For now we will send it through with no noise
+ # sample_imgs = self.noise_scheduler.add_noise(sample_imgs, timestep)
+ self._original_unet_forward(reference_latents, timestep, encoder_hidden_states, *args, **kwargs)
+ self.reference_mode = "read"
+ self.has_memory = True
+ return None
+
+ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+ for attn_processor in self.adapter_modules:
+ yield from attn_processor.parameters(recurse)
+ # yield from self.image_proj_model.parameters(recurse)
+ # if self.config.train_image_encoder:
+ # yield from self.image_encoder.parameters(recurse)
+ # if self.config.train_image_encoder:
+ # yield from self.image_encoder.parameters(recurse)
+ # self.image_encoder.train()
+ # else:
+ # for attn_processor in self.adapter_modules:
+ # yield from attn_processor.parameters(recurse)
+ # yield from self.image_proj_model.parameters(recurse)
+
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+ strict = False
+ # self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
+ self.adapter_modules.load_state_dict(state_dict["reference_adapter"], strict=strict)
+
+    def enable_gradient_checkpointing(self):
+        # this adapter has no image encoder; guard so a stray call does not raise an AttributeError
+        if hasattr(self, "image_encoder"):
+            self.image_encoder.gradient_checkpointing = True
+
+ def clear_memory(self):
+ for attn_processor in self.adapter_modules:
+ if isinstance(attn_processor, ReferenceAttnProcessor2_0):
+ attn_processor._memory = None
+ self.has_memory = False
diff --git a/toolkit/resampler.py b/toolkit/resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ace5a3a18d78e6f5b712dd587aaa45827247dc6
--- /dev/null
+++ b/toolkit/resampler.py
@@ -0,0 +1,160 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+# and https://github.com/tencent-ailab/IP-Adapter/blob/9fc189e3fb389cc2b60a7d0c0850e083a716ea6e/ip_adapter/resampler.py
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+            latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
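+# Perceiver-style resampler: a fixed set of learned latent query tokens attends to a variable-length
+# sequence of image embeddings over `depth` attention/feed-forward blocks and is projected to
+# `output_dim`, yielding `num_queries` tokens (plus optional mean-pooled latents) regardless of input length.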
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ max_seq_len: int = 257, # CLIP tokens + CLS token
+ apply_pos_emb: bool = False,
+ num_latents_mean_pooled: int = 0,
+ # number of latents derived from mean pooled representation of the sequence
+ ):
+ super().__init__()
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.to_latents_from_mean_pooled_seq = (
+ nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * num_latents_mean_pooled),
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+ )
+ if num_latents_mean_pooled > 0
+ else None
+ )
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ if self.pos_emb is not None:
+ n, device = x.shape[1], x.device
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
+ x = x + pos_emb
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ if self.to_latents_from_mean_pooled_seq:
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
+def masked_mean(t, *, dim, mask=None):
+ if mask is None:
+ return t.mean(dim=dim)
+
+ denom = mask.sum(dim=dim, keepdim=True)
+ mask = rearrange(mask, "b n -> b n 1")
+ masked_t = t.masked_fill(~mask, 0.0)
+
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
diff --git a/toolkit/sampler.py b/toolkit/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b0311b3b4c6b788bcf18c78df2c86a134465c0
--- /dev/null
+++ b/toolkit/sampler.py
@@ -0,0 +1,164 @@
+import copy
+import math
+
+from diffusers import (
+ DDPMScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ DPMSolverSinglestepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ DDIMScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ KDPM2DiscreteScheduler,
+ KDPM2AncestralDiscreteScheduler,
+ LCMScheduler,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+
+from k_diffusion.external import CompVisDenoiser
+
+from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
+
+# scheduler:
+SCHEDULER_LINEAR_START = 0.00085
+SCHEDULER_LINEAR_END = 0.0120
+SCHEDULER_TIMESTEPS = 1000
+SCHEDLER_SCHEDULE = "scaled_linear"
+
+sd_config = {
+ "_class_name": "EulerAncestralDiscreteScheduler",
+ "_diffusers_version": "0.24.0.dev0",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "interpolation_type": "linear",
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": False,
+ # "skip_prk_steps": False, # for training
+ "skip_prk_steps": True,
+ # "steps_offset": 1,
+ "steps_offset": 0,
+ # "timestep_spacing": "trailing", # for training
+ "timestep_spacing": "leading",
+ "trained_betas": None
+}
+
+pixart_config = {
+ "_class_name": "DPMSolverMultistepScheduler",
+ "_diffusers_version": "0.22.0.dev0",
+ "algorithm_type": "dpmsolver++",
+ "beta_end": 0.02,
+ "beta_schedule": "linear",
+ "beta_start": 0.0001,
+ "dynamic_thresholding_ratio": 0.995,
+ "euler_at_final": False,
+ # "lambda_min_clipped": -Infinity,
+ "lambda_min_clipped": -math.inf,
+ "lower_order_final": True,
+ "num_train_timesteps": 1000,
+ "prediction_type": "epsilon",
+ "sample_max_value": 1.0,
+ "solver_order": 2,
+ "solver_type": "midpoint",
+ "steps_offset": 0,
+ "thresholding": False,
+ "timestep_spacing": "linspace",
+ "trained_betas": None,
+ "use_karras_sigmas": False,
+ "use_lu_lambdas": False,
+ "variance_type": None
+}
+
+
+def get_sampler(
+ sampler: str,
+ kwargs: dict = None,
+ arch: str = "sd"
+):
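+    """
+    Build a diffusers scheduler from a short sampler name (e.g. "ddim", "euler_a", "k_dpmsolver++",
+    "flowmatch"). Names prefixed with "k_" enable Karras sigmas; any `kwargs` are merged into the
+    base scheduler config, which is chosen from `sd_config` or `pixart_config` based on `arch`.
+    """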
+ sched_init_args = {}
+ if kwargs is not None:
+ sched_init_args.update(kwargs)
+
+ config_to_use = copy.deepcopy(sd_config) if arch == "sd" else copy.deepcopy(pixart_config)
+
+ if sampler.startswith("k_"):
+ sched_init_args["use_karras_sigmas"] = True
+
+ if sampler == "ddim":
+ scheduler_cls = DDIMScheduler
+ elif sampler == "ddpm": # ddpm is not supported ?
+ scheduler_cls = DDPMScheduler
+ elif sampler == "pndm":
+ scheduler_cls = PNDMScheduler
+ elif sampler == "lms" or sampler == "k_lms":
+ scheduler_cls = LMSDiscreteScheduler
+ elif sampler == "euler" or sampler == "k_euler":
+ scheduler_cls = EulerDiscreteScheduler
+ elif sampler == "euler_a":
+ scheduler_cls = EulerAncestralDiscreteScheduler
+ elif sampler == "dpmsolver" or sampler == "dpmsolver++" or sampler == "k_dpmsolver" or sampler == "k_dpmsolver++":
+ scheduler_cls = DPMSolverMultistepScheduler
+ sched_init_args["algorithm_type"] = sampler.replace("k_", "")
+ elif sampler == "dpmsingle":
+ scheduler_cls = DPMSolverSinglestepScheduler
+ elif sampler == "heun":
+ scheduler_cls = HeunDiscreteScheduler
+ elif sampler == "dpm_2":
+ scheduler_cls = KDPM2DiscreteScheduler
+ elif sampler == "dpm_2_a":
+ scheduler_cls = KDPM2AncestralDiscreteScheduler
+ elif sampler == "lcm":
+ scheduler_cls = LCMScheduler
+ elif sampler == "custom_lcm":
+ scheduler_cls = CustomLCMScheduler
+ elif sampler == "flowmatch":
+ scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
+ config_to_use = {
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
+ "_diffusers_version": "0.29.0.dev0",
+ "num_train_timesteps": 1000,
+ "shift": 3.0
+ }
+ else:
+ raise ValueError(f"Sampler {sampler} not supported")
+
+
+ config = copy.deepcopy(config_to_use)
+ config.update(sched_init_args)
+
+ scheduler = scheduler_cls.from_config(config)
+
+ return scheduler
+
+
+# testing
+if __name__ == "__main__":
+    from diffusers import DiffusionPipeline
+    import torch
+
+    inference_steps = 20
+    seed = 0  # fixed seed so the smoke test is reproducible
+
+    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
+    pipe = pipe.to("cuda")
+
+    prompt = "an astronaut riding a horse on mars"
+    pipe.set_scheduler("sample_heun")
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    image = pipe(prompt, generator=generator, num_inference_steps=inference_steps).images[0]
+
+    image.save("./astronaut_heun_k_diffusion.png")
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..440eb4fa1220fab83a0d4ce3cdae928b873031ff
--- /dev/null
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -0,0 +1,110 @@
+import math
+from typing import Union
+
+from diffusers import FlowMatchEulerDiscreteScheduler
+import torch
+
+
+class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.init_noise_sigma = 1.0
+
+ with torch.no_grad():
+ # create weights for timesteps
+ num_timesteps = 1000
+ # Bell-Shaped Mean-Normalized Timestep Weighting
+ # bsmntw? need a better name
+
+ x = torch.arange(num_timesteps, dtype=torch.float32)
+ y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
+
+ # Shift minimum to 0
+ y_shifted = y - y.min()
+
+ # Scale to make mean 1
+ bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
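+            # after mean normalization the weights sum to num_timesteps, i.e. they average to 1,
+            # so re-weighting the per-timestep loss does not change its overall scale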
+
+ # only do half bell
+ hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
+
+ # flatten second half to max
+ hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
+
+ # Create linear timesteps from 1000 to 0
+ timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
+
+ self.linear_timesteps = timesteps
+ self.linear_timesteps_weights = bsmntw_weighing
+ self.linear_timesteps_weights2 = hbsmntw_weighing
+ pass
+
+ def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
+ # Get the indices of the timesteps
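+        # note: this exact-match lookup assumes every queried timestep appears exactly once in
+        # self.timesteps (e.g. after set_train_timesteps(..., linear=True)); otherwise .item() will raise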
+ step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
+
+ # Get the weights for the timesteps
+ if v2:
+ weights = self.linear_timesteps_weights2[step_indices].flatten()
+ else:
+ weights = self.linear_timesteps_weights[step_indices].flatten()
+
+ return weights
+
+ def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
+ sigmas = self.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = self.timesteps.to(device)
+ timesteps = timesteps.to(device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+
+ return sigma
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ ) -> torch.Tensor:
+ ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
+ ## Add noise according to flow matching.
+ ## zt = (1 - texp) * x + texp * z1
+
+ # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # timestep needs to be in [0, 1], we store them in [0, 1000]
+ # noisy_sample = (1 - timestep) * latent + timestep * noise
+ t_01 = (timesteps / 1000).to(original_samples.device)
+ noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
+
+ # n_dim = original_samples.ndim
+ # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
+ # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
+ return noisy_model_input
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ return sample
+
+ def set_train_timesteps(self, num_timesteps, device, linear=False):
+ if linear:
+ timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
+ self.timesteps = timesteps
+ return timesteps
+ else:
+            # bias the sampled timesteps toward the middle of the schedule (inference biases toward the first steps)
+ # Generate values from 0 to 1
+ t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
+
+ # Scale and reverse the values to go from 1000 to 0
+ timesteps = ((1 - t) * 1000)
+
+ # Sort the timesteps in descending order
+ timesteps, _ = torch.sort(timesteps, descending=True)
+
+ self.timesteps = timesteps.to(device=device)
+
+ return timesteps
diff --git a/toolkit/samplers/custom_lcm_scheduler.py b/toolkit/samplers/custom_lcm_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..132052af74186b2597060d66c764c8b4ed841378
--- /dev/null
+++ b/toolkit/samplers/custom_lcm_scheduler.py
@@ -0,0 +1,553 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class LCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ denoised: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class CustomLCMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ original_inference_steps (`int`, *optional*, defaults to 50):
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_scaling (`float`, defaults to 10.0):
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+ error at the default of `10.0` is already pretty small).
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085,
+ beta_end: float = 0.012,
+ beta_schedule: str = "scaled_linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ original_inference_steps: int = 50,
+ clip_sample: bool = False,
+ clip_sample_range: float = 1.0,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ timestep_scaling: float = 10.0,
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+ self.original_inference_steps = 50
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ self.train_timesteps = 1000
+
+ self._step_index = None
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ index_candidates = (self.timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(index_candidates) > 1:
+ step_index = index_candidates[1]
+ else:
+ step_index = index_candidates[0]
+
+ self._step_index = step_index.item()
+
+ @property
+ def step_index(self):
+ return self._step_index
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+        strength: float = 1.0,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            strength (`float`, *optional*, defaults to 1.0):
+                Fraction of the original training schedule to keep when building the candidate LCM timestep
+                schedule: the linearly-spaced candidates span `int(original_steps * strength)` of the training
+                steps, and `num_inference_steps` timesteps are then selected from them, evenly spaced in
+                terms of indices.
+ """
+
+ original_inference_steps = self.original_inference_steps
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ original_steps = (
+ original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
+ )
+
+ if original_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ if num_inference_steps > original_steps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
+ f" {original_steps} because the final timestep schedule will be a subset of the"
+ f" `original_inference_steps`-sized initial timestep schedule."
+ )
+
+ # LCM Timesteps Setting
+ # The skipping step parameter k from the paper.
+ k = self.config.num_train_timesteps // original_steps
+ # LCM Training/Distillation Steps Schedule
+ # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+
+ if skipping_step < 1:
+ raise ValueError(
+ f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
+ )
+
+ # LCM Inference Steps Schedule
+ lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
+ # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
+ inference_indices = np.linspace(0, len(lcm_origin_timesteps) - 1, num=num_inference_steps)
+ inference_indices = np.floor(inference_indices).astype(np.int64)
+ timesteps = lcm_origin_timesteps[inference_indices]
+
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
+
+ self._step_index = None
+
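+    # consistency-model boundary conditions: with s = timestep_scaling and sigma_data = 0.5,
+    #   c_skip(t) = sigma_data^2 / ((s*t)^2 + sigma_data^2)
+    #   c_out(t)  = s*t / sqrt((s*t)^2 + sigma_data^2)
+    # at t = 0 the model output is ignored entirely (c_skip = 1, c_out = 0)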
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
+ self.sigma_data = 0.5 # Default: 0.5
+ scaled_timestep = timestep * self.config.timestep_scaling
+
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[LCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
+ Returns:
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # 1. get previous step value
+ prev_step_index = self.step_index + 1
+ if prev_step_index < len(self.timesteps):
+ prev_timestep = self.timesteps[prev_step_index]
+ else:
+ prev_timestep = timestep
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 3. Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
+ if self.config.prediction_type == "epsilon": # noise-prediction
+ predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
+ elif self.config.prediction_type == "sample": # x-prediction
+ predicted_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction": # v-prediction
+ predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+ " `v_prediction` for `LCMScheduler`."
+ )
+
+ # 5. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ predicted_original_sample = self._threshold_sample(predicted_original_sample)
+ elif self.config.clip_sample:
+ predicted_original_sample = predicted_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 6. Denoise model output using boundary conditions
+ denoised = c_out * predicted_original_sample + c_skip * sample
+
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
+ # Noise is not used on the final timestep of the timestep schedule.
+ # This also means that noise is not used for one-step sampling.
+ if self.step_index != self.num_inference_steps - 1:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+ )
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
+ else:
+ prev_sample = denoised
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample, denoised)
+
+ return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
\ No newline at end of file
diff --git a/toolkit/saving.py b/toolkit/saving.py
new file mode 100644
index 0000000000000000000000000000000000000000..7abc7d5058d347b06cdcfcd4b452223e7920b346
--- /dev/null
+++ b/toolkit/saving.py
@@ -0,0 +1,330 @@
+import json
+import os
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Literal, Optional, Union
+
+import torch
+from safetensors.torch import load_file, save_file
+
+from toolkit.train_tools import get_torch_dtype
+from toolkit.paths import KEYMAPS_ROOT
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def get_slices_from_string(s: str) -> tuple:
+ slice_strings = s.split(',')
+ slices = [eval(f"slice({component.strip()})") for component in slice_strings]
+ return tuple(slices)
+
+
+def convert_state_dict_to_ldm_with_mapping(
+ diffusers_state_dict: 'OrderedDict',
+ mapping_path: str,
+ base_path: Union[str, None] = None,
+ device: str = 'cpu',
+ dtype: torch.dtype = torch.float32
+) -> 'OrderedDict':
+ converted_state_dict = OrderedDict()
+
+ # load mapping
+ with open(mapping_path, 'r') as f:
+ mapping = json.load(f, object_pairs_hook=OrderedDict)
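+    # expected mapping layout (inferred from its use below): "ldm_diffusers_keymap" maps ldm key ->
+    # diffusers key, "ldm_diffusers_shape_map" maps ldm key -> target shape, and
+    # "ldm_diffusers_operator_map" maps ldm key -> {"cat": [diffusers keys, ...]} or {"slice": [...]}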
+
+ # keep track of keys not matched
+ ldm_matched_keys = []
+ diffusers_matched_keys = []
+
+ ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
+ ldm_diffusers_shape_map = mapping['ldm_diffusers_shape_map']
+ ldm_diffusers_operator_map = mapping['ldm_diffusers_operator_map']
+
+ # load base if it exists
+    # the base just has some keys (such as timing ids) that diffusers doesn't have or that don't match
+ if base_path is not None:
+ converted_state_dict = load_file(base_path, device)
+ # convert to the right dtype
+ for key in converted_state_dict:
+ converted_state_dict[key] = converted_state_dict[key].to(device, dtype=dtype)
+
+ # process operators first
+ for ldm_key in ldm_diffusers_operator_map:
+ # if the key cat is in the ldm key, we need to process it
+ if 'cat' in ldm_diffusers_operator_map[ldm_key]:
+ cat_list = []
+ for diffusers_key in ldm_diffusers_operator_map[ldm_key]['cat']:
+ cat_list.append(diffusers_state_dict[diffusers_key].detach())
+ converted_state_dict[ldm_key] = torch.cat(cat_list, dim=0).to(device, dtype=dtype)
+ diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['cat'])
+ ldm_matched_keys.append(ldm_key)
+ if 'slice' in ldm_diffusers_operator_map[ldm_key]:
+ tensor_to_slice = diffusers_state_dict[ldm_diffusers_operator_map[ldm_key]['slice'][0]]
+            # the second entry of the 'slice' operator is the slice spec itself, not a state dict key
+            slice_text = ldm_diffusers_operator_map[ldm_key]['slice'][1]
+ converted_state_dict[ldm_key] = tensor_to_slice[get_slices_from_string(slice_text)].detach().to(device,
+ dtype=dtype)
+ diffusers_matched_keys.extend(ldm_diffusers_operator_map[ldm_key]['slice'])
+ ldm_matched_keys.append(ldm_key)
+
+ # process the rest of the keys
+ for ldm_key in ldm_diffusers_keymap:
+ # if the key is in the ldm key, we need to process it
+ if ldm_diffusers_keymap[ldm_key] in diffusers_state_dict:
+ tensor = diffusers_state_dict[ldm_diffusers_keymap[ldm_key]].detach().to(device, dtype=dtype)
+ # see if we need to reshape
+ if ldm_key in ldm_diffusers_shape_map:
+ tensor = tensor.view(ldm_diffusers_shape_map[ldm_key][0])
+ converted_state_dict[ldm_key] = tensor
+ diffusers_matched_keys.append(ldm_diffusers_keymap[ldm_key])
+ ldm_matched_keys.append(ldm_key)
+
+    # see if any are missing from the known mapping
+ mapped_diffusers_keys = list(ldm_diffusers_keymap.values())
+ mapped_ldm_keys = list(ldm_diffusers_keymap.keys())
+
+ missing_diffusers_keys = [x for x in mapped_diffusers_keys if x not in diffusers_matched_keys]
+ missing_ldm_keys = [x for x in mapped_ldm_keys if x not in ldm_matched_keys]
+
+ if len(missing_diffusers_keys) > 0:
+ print(f"WARNING!!!! Missing {len(missing_diffusers_keys)} diffusers keys")
+ print(missing_diffusers_keys)
+ if len(missing_ldm_keys) > 0:
+ print(f"WARNING!!!! Missing {len(missing_ldm_keys)} ldm keys")
+ print(missing_ldm_keys)
+
+ return converted_state_dict
+
+
+def get_ldm_state_dict_from_diffusers(
+ state_dict: 'OrderedDict',
+ sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega', 'sdxl_refiner'] = '2',
+ device='cpu',
+ dtype=get_torch_dtype('fp32'),
+):
+ if sd_version == '1':
+ base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1_ldm_base.safetensors')
+ mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd1.json')
+ elif sd_version == '2':
+ base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2_ldm_base.safetensors')
+ mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sd2.json')
+ elif sd_version == 'sdxl':
+ # load our base
+ base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl_ldm_base.safetensors')
+ mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_sdxl.json')
+ elif sd_version == 'ssd':
+ # load our base
+ base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd_ldm_base.safetensors')
+ mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_ssd.json')
+ elif sd_version == 'vega':
+ # load our base
+ base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega_ldm_base.safetensors')
+ mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_vega.json')
+ elif sd_version == 'sdxl_refiner':
+ # load our base
+ base_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner_ldm_base.safetensors')
+ mapping_path = os.path.join(KEYMAPS_ROOT, 'stable_diffusion_refiner.json')
+ else:
+ raise ValueError(f"Invalid sd_version {sd_version}")
+
+ # convert the state dict
+ return convert_state_dict_to_ldm_with_mapping(
+ state_dict,
+ mapping_path,
+ base_path,
+ device=device,
+ dtype=dtype
+ )
+
+
+def save_ldm_model_from_diffusers(
+ sd: 'StableDiffusion',
+ output_file: str,
+ meta: 'OrderedDict',
+ save_dtype=get_torch_dtype('fp16'),
+ sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
+):
+ converted_state_dict = get_ldm_state_dict_from_diffusers(
+ sd.state_dict(),
+ sd_version,
+ device='cpu',
+ dtype=save_dtype
+ )
+
+ # make sure parent folder exists
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ save_file(converted_state_dict, output_file, metadata=meta)
+
+
+def save_lora_from_diffusers(
+ lora_state_dict: 'OrderedDict',
+ output_file: str,
+ meta: 'OrderedDict',
+ save_dtype=get_torch_dtype('fp16'),
+ sd_version: Literal['1', '2', 'sdxl', 'ssd', 'vega'] = '2'
+):
+ converted_state_dict = OrderedDict()
+    # only handles sdxl-family models for now
+    if sd_version != 'sdxl' and sd_version != 'ssd' and sd_version != 'vega':
+        raise ValueError(f"Invalid sd_version {sd_version}")
+    for key, value in lora_state_dict.items():
+        # todo verify if this works with ssd
+        # text encoders share keys for some reason
+        if key.startswith('lora_te'):
+            converted_state_dict[key] = value.detach().to('cpu', dtype=save_dtype)
+        else:
+            # todo: non text-encoder keys are not converted here yet and are skipped
+            pass
+
+ # make sure parent folder exists
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ save_file(converted_state_dict, output_file, metadata=meta)
+
+
+def save_t2i_from_diffusers(
+ t2i_state_dict: 'OrderedDict',
+ output_file: str,
+ meta: 'OrderedDict',
+ dtype=get_torch_dtype('fp16'),
+):
+ # todo: test compatibility with non diffusers
+ converted_state_dict = OrderedDict()
+ for key, value in t2i_state_dict.items():
+ converted_state_dict[key] = value.detach().to('cpu', dtype=dtype)
+
+ # make sure parent folder exists
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ save_file(converted_state_dict, output_file, metadata=meta)
+
+
+def load_t2i_model(
+ path_to_file,
+        device: Union[str, torch.device] = 'cpu',
+ dtype: torch.dtype = torch.float32
+):
+ raw_state_dict = load_file(path_to_file, device)
+ converted_state_dict = OrderedDict()
+ for key, value in raw_state_dict.items():
+ # todo see if we need to convert dict
+ converted_state_dict[key] = value.detach().to(device, dtype=dtype)
+ return converted_state_dict
+
+
+
+
+def save_ip_adapter_from_diffusers(
+ combined_state_dict: 'OrderedDict',
+ output_file: str,
+ meta: 'OrderedDict',
+ dtype=get_torch_dtype('fp16'),
+ direct_save: bool = False
+):
+ # todo: test compatibility with non diffusers
+
+ converted_state_dict = OrderedDict()
+ for module_name, state_dict in combined_state_dict.items():
+ if direct_save:
+ converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype)
+ else:
+ for key, value in state_dict.items():
+ converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
+
+ # make sure parent folder exists
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ save_file(converted_state_dict, output_file, metadata=meta)
+
+
+def load_ip_adapter_model(
+ path_to_file,
+        device: Union[str, torch.device] = 'cpu',
+ dtype: torch.dtype = torch.float32,
+ direct_load: bool = False
+):
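+    # safetensors checkpoints are stored flat as "module_name.key" (see save_ip_adapter_from_diffusers);
+    # unless direct_load is set, they are regrouped here into a nested {module_name: state_dict} mapping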
+ # check if it is safetensors or checkpoint
+ if path_to_file.endswith('.safetensors'):
+ raw_state_dict = load_file(path_to_file, device)
+ combined_state_dict = OrderedDict()
+ if direct_load:
+ return raw_state_dict
+ for combo_key, value in raw_state_dict.items():
+ key_split = combo_key.split('.')
+ module_name = key_split.pop(0)
+ if module_name not in combined_state_dict:
+ combined_state_dict[module_name] = OrderedDict()
+ combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
+ return combined_state_dict
+ else:
+ return torch.load(path_to_file, map_location=device)
+
+def load_custom_adapter_model(
+ path_to_file,
+        device: Union[str, torch.device] = 'cpu',
+ dtype: torch.dtype = torch.float32
+):
+ # check if it is safetensors or checkpoint
+ if path_to_file.endswith('.safetensors'):
+ raw_state_dict = load_file(path_to_file, device)
+ combined_state_dict = OrderedDict()
+ device = device if isinstance(device, torch.device) else torch.device(device)
+ dtype = dtype if isinstance(dtype, torch.dtype) else get_torch_dtype(dtype)
+ for combo_key, value in raw_state_dict.items():
+ key_split = combo_key.split('.')
+ module_name = key_split.pop(0)
+ if module_name not in combined_state_dict:
+ combined_state_dict[module_name] = OrderedDict()
+ combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
+ return combined_state_dict
+ else:
+ return torch.load(path_to_file, map_location=device)
+
+
+def get_lora_keymap_from_model_keymap(model_keymap: 'OrderedDict') -> 'OrderedDict':
+ lora_keymap = OrderedDict()
+
+ # see if we have dual text encoders (a key that starts with conditioner.embedders.1)
+ has_dual_text_encoders = False
+ for key in model_keymap:
+ if key.startswith('conditioner.embedders.1'):
+ has_dual_text_encoders = True
+ break
+ # map through the keys and values
+ for key, value in model_keymap.items():
+ # ignore bias weights
+ if key.endswith('bias'):
+ continue
+ if key.endswith('.weight'):
+ # remove the .weight
+ key = key[:-7]
+ if value.endswith(".weight"):
+ # remove the .weight
+ value = value[:-7]
+
+ # unet for all
+ key = key.replace('model.diffusion_model', 'lora_unet')
+ if value.startswith('unet'):
+ value = f"lora_{value}"
+
+ # text encoder
+ if has_dual_text_encoders:
+ key = key.replace('conditioner.embedders.0', 'lora_te1')
+ key = key.replace('conditioner.embedders.1', 'lora_te2')
+ if value.startswith('te0') or value.startswith('te1'):
+ value = f"lora_{value}"
+ # str.replace returns a new string, so reassign the result
+ value = value.replace('lora_te1', 'lora_te2')
+ value = value.replace('lora_te0', 'lora_te1')
+
+ key = key.replace('cond_stage_model.transformer', 'lora_te')
+
+ if value.startswith('te_'):
+ value = f"lora_{value}"
+
+ # replace periods with underscores
+ key = key.replace('.', '_')
+ value = value.replace('.', '_')
+
+ # add all the weights
+ lora_keymap[f"{key}.lora_down.weight"] = f"{value}.lora_down.weight"
+ lora_keymap[f"{key}.lora_down.bias"] = f"{value}.lora_down.bias"
+ lora_keymap[f"{key}.lora_up.weight"] = f"{value}.lora_up.weight"
+ lora_keymap[f"{key}.lora_up.bias"] = f"{value}.lora_up.bias"
+ lora_keymap[f"{key}.alpha"] = f"{value}.alpha"
+
+ return lora_keymap
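+
+# Illustrative example (made-up key names, chosen only to show the shape of the mapping):
+# a model keymap entry such as
+#   "model.diffusion_model.input_blocks.0.weight" -> "unet.conv_in.weight"
+# yields LoRA keymap entries like
+#   "lora_unet_input_blocks_0.lora_down.weight" -> "lora_unet_conv_in.lora_down.weight"
+# along with the matching lora_up, bias, and alpha entries.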
diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f8f61aeb8f63b12ee8f8f2800385c11ec3b7bf
--- /dev/null
+++ b/toolkit/scheduler.py
@@ -0,0 +1,57 @@
+import torch
+from typing import Optional
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup
+
+
+def get_lr_scheduler(
+ name: Optional[str],
+ optimizer: torch.optim.Optimizer,
+ **kwargs,
+):
+ if name == "cosine":
+ if 'total_iters' in kwargs:
+ kwargs['T_max'] = kwargs.pop('total_iters')
+ return torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer, **kwargs
+ )
+ elif name == "cosine_with_restarts":
+ if 'total_iters' in kwargs:
+ kwargs['T_0'] = kwargs.pop('total_iters')
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+ optimizer, **kwargs
+ )
+ elif name == "step":
+
+ return torch.optim.lr_scheduler.StepLR(
+ optimizer, **kwargs
+ )
+ elif name == "constant":
+ if 'factor' not in kwargs:
+ kwargs['factor'] = 1.0
+
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, **kwargs)
+ elif name == "linear":
+
+ return torch.optim.lr_scheduler.LinearLR(
+ optimizer, **kwargs
+ )
+ elif name == 'constant_with_warmup':
+ # see if num_warmup_steps is in kwargs
+ if 'num_warmup_steps' not in kwargs:
+ print(f"WARNING: num_warmup_steps not in kwargs. Using default value of 1000")
+ kwargs['num_warmup_steps'] = 1000
+ # drop total_iters if it was passed in; the warmup schedule does not accept it
+ kwargs.pop('total_iters', None)
+ return get_constant_schedule_with_warmup(optimizer, **kwargs)
+ else:
+ # try to use a diffusers scheduler
+ print(f"Trying to use diffusers scheduler {name}")
+ try:
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ return schedule_func(optimizer, **kwargs)
+ except Exception as e:
+ print(e)
+ raise ValueError(
+ "Scheduler must be cosine, cosine_with_restarts, step, linear, constant, constant_with_warmup, or a valid diffusers scheduler name"
+ )
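+
+
+# Example usage (a minimal sketch; "model" is a placeholder module, not something defined here):
+#
+#   model = torch.nn.Linear(4, 4)
+#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+#   # 'total_iters' is remapped to T_max for the cosine schedule above
+#   scheduler = get_lr_scheduler("cosine", optimizer, total_iters=1000)
+#   for _ in range(10):
+#       optimizer.step()
+#       scheduler.step()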
diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eeecc323fefb7b06fdaf30ff9f80f5399b0ce09
--- /dev/null
+++ b/toolkit/sd_device_states_presets.py
@@ -0,0 +1,107 @@
+from typing import Union
+
+import torch
+import copy
+
+empty_preset = {
+ 'vae': {
+ 'training': False,
+ 'device': 'cpu',
+ },
+ 'unet': {
+ 'training': False,
+ 'requires_grad': False,
+ 'device': 'cpu',
+ },
+ 'text_encoder': {
+ 'training': False,
+ 'requires_grad': False,
+ 'device': 'cpu',
+ },
+ 'adapter': {
+ 'training': False,
+ 'requires_grad': False,
+ 'device': 'cpu',
+ },
+ 'refiner_unet': {
+ 'training': False,
+ 'requires_grad': False,
+ 'device': 'cpu',
+ },
+}
+
+
+def get_train_sd_device_state_preset(
+ device: Union[str, torch.device],
+ train_unet: bool = False,
+ train_text_encoder: bool = False,
+ cached_latents: bool = False,
+ train_lora: bool = False,
+ train_adapter: bool = False,
+ train_embedding: bool = False,
+ train_decorator: bool = False,
+ train_refiner: bool = False,
+ unload_text_encoder: bool = False,
+ require_grads: bool = True,
+):
+ preset = copy.deepcopy(empty_preset)
+ if not cached_latents:
+ preset['vae']['device'] = device
+
+ if train_unet:
+ preset['unet']['training'] = True
+ preset['unet']['requires_grad'] = require_grads
+ preset['unet']['device'] = device
+ else:
+ preset['unet']['device'] = device
+
+ if train_text_encoder:
+ preset['text_encoder']['training'] = True
+ preset['text_encoder']['requires_grad'] = require_grads
+ preset['text_encoder']['device'] = device
+ else:
+ preset['text_encoder']['device'] = device
+
+ if train_embedding:
+ preset['text_encoder']['training'] = True
+ preset['text_encoder']['requires_grad'] = require_grads
+ preset['unet']['training'] = True
+
+ if train_refiner:
+ preset['refiner_unet']['training'] = True
+ preset['refiner_unet']['requires_grad'] = require_grads
+ preset['refiner_unet']['device'] = device
+ # if not training unet, move that to cpu
+ if not train_unet:
+ preset['unet']['device'] = 'cpu'
+
+ if train_lora:
+ # preset['text_encoder']['requires_grad'] = False
+ preset['unet']['requires_grad'] = False
+ if train_refiner:
+ preset['refiner_unet']['requires_grad'] = False
+
+ if train_adapter:
+ preset['adapter']['requires_grad'] = require_grads
+ preset['adapter']['training'] = True
+ preset['adapter']['device'] = device
+ preset['unet']['training'] = True
+ preset['unet']['requires_grad'] = False
+ preset['unet']['device'] = device
+ preset['text_encoder']['device'] = device
+
+ if train_decorator:
+ preset['text_encoder']['training'] = False
+ preset['text_encoder']['requires_grad'] = False
+ preset['text_encoder']['device'] = device
+ preset['unet']['training'] = True
+ preset['unet']['requires_grad'] = False
+ preset['unet']['device'] = device
+
+ if unload_text_encoder:
+ preset['text_encoder']['training'] = False
+ preset['text_encoder']['requires_grad'] = False
+ preset['text_encoder']['device'] = 'cpu'
+
+ return preset
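+
+
+# Example (illustrative): a preset for LoRA training with cached latents keeps the VAE on
+# cpu, moves the unet and text encoder to the training device, and leaves unet grads off
+# because the LoRA network holds the trainable parameters.
+#
+#   preset = get_train_sd_device_state_preset(
+#       device='cuda:0',
+#       train_unet=True,
+#       train_lora=True,
+#       cached_latents=True,
+#   )
+#   # preset['vae']['device'] == 'cpu'
+#   # preset['unet']['requires_grad'] is False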
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..23439fbc310d5153528e99114f31ca5c2988de0a
--- /dev/null
+++ b/toolkit/stable_diffusion_model.py
@@ -0,0 +1,2754 @@
+import copy
+import gc
+import json
+import random
+import shutil
+import typing
+from typing import Union, List, Literal, Iterator
+import sys
+import os
+from collections import OrderedDict
+import copy
+import yaml
+from PIL import Image
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
+ ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+from safetensors.torch import save_file, load_file
+from torch import autocast
+from torch.nn import Parameter
+from torch.utils.checkpoint import checkpoint
+from tqdm import tqdm
+from torchvision.transforms import Resize, transforms
+
+from toolkit.assistant_lora import load_assistant_lora_from_path
+from toolkit.clip_vision_adapter import ClipVisionAdapter
+from toolkit.custom_adapter import CustomAdapter
+from toolkit.dequantize import patch_dequantization_on_save
+from toolkit.ip_adapter import IPAdapter
+from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
+ convert_vae_state_dict, load_vae
+from toolkit import train_tools
+from toolkit.config_modules import ModelConfig, GenerateImageConfig
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.decorator import Decorator
+from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
+from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
+from toolkit.reference_adapter import ReferenceAdapter
+from toolkit.sampler import get_sampler
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
+from toolkit.sd_device_states_presets import empty_preset
+from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+from einops import rearrange, repeat
+import torch
+from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
+ StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
+ StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
+ StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
+ StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
+ StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
+ FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel
+import diffusers
+from diffusers import \
+ AutoencoderKL, \
+ UNet2DConditionModel
+from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
+from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+
+from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
+from huggingface_hub import hf_hub_download
+from toolkit.models.flux import bypass_flux_guidance, restore_flux_guidance
+
+from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from toolkit.lora_special import LoRASpecialNetwork
+
+# tell it to shut up
+diffusers.logging.set_verbosity(diffusers.logging.ERROR)
+
+SD_PREFIX_VAE = "vae"
+SD_PREFIX_UNET = "unet"
+SD_PREFIX_REFINER_UNET = "refiner_unet"
+SD_PREFIX_TEXT_ENCODER = "te"
+
+SD_PREFIX_TEXT_ENCODER1 = "te0"
+SD_PREFIX_TEXT_ENCODER2 = "te1"
+
+# prefixed diffusers keys
+DO_NOT_TRAIN_WEIGHTS = [
+ "unet_time_embedding.linear_1.bias",
+ "unet_time_embedding.linear_1.weight",
+ "unet_time_embedding.linear_2.bias",
+ "unet_time_embedding.linear_2.weight",
+ "refiner_unet_time_embedding.linear_1.bias",
+ "refiner_unet_time_embedding.linear_1.weight",
+ "refiner_unet_time_embedding.linear_2.bias",
+ "refiner_unet_time_embedding.linear_2.weight",
+]
+
+DeviceStatePreset = Literal['cache_latents', 'generate']
+
+
+class BlankNetwork:
+
+ def __init__(self):
+ self.multiplier = 1.0
+ self.is_active = True
+ self.is_merged_in = False
+ self.can_merge_in = False
+
+ def __enter__(self):
+ self.is_active = True
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.is_active = False
+
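+# Illustrative note (a sketch of intended use, based on generate_images below): BlankNetwork
+# stands in for a LoRA network when none is loaded, so sampling code can always write
+# `with network:` whether or not a real network is attached.
+#
+#   network = BlankNetwork()
+#   with network:
+#       assert network.is_active      # set by __enter__
+#   assert not network.is_active     # reset by __exit__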
+
+def flush():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+UNET_IN_CHANNELS = 4  # Stable Diffusion's in_channels is fixed at 4. Same for XL.
+# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
+
+
+
+class StableDiffusion:
+
+ def __init__(
+ self,
+ device,
+ model_config: ModelConfig,
+ dtype='fp16',
+ custom_pipeline=None,
+ noise_scheduler=None,
+ quantize_device=None,
+ ):
+ self.custom_pipeline = custom_pipeline
+ self.device = device
+ self.dtype = dtype
+ self.torch_dtype = get_torch_dtype(dtype)
+ self.device_torch = torch.device(self.device)
+
+ self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
+ model_config.vae_device)
+ self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
+
+ self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
+ model_config.te_device)
+ self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
+
+ self.model_config = model_config
+ self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"
+
+ self.device_state = None
+
+ self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
+ self.vae: Union[None, 'AutoencoderKL']
+ self.unet: Union[None, 'UNet2DConditionModel']
+ self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
+ self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
+ self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
+
+ self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
+ self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None
+
+ # sdxl stuff
+ self.logit_scale = None
+ self.ckppt_info = None
+ self.is_loaded = False
+
+ # to hold network if there is one
+ self.network = None
+ self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
+ self.decorator: Union[Decorator, None] = None
+ self.is_xl = model_config.is_xl
+ self.is_v2 = model_config.is_v2
+ self.is_ssd = model_config.is_ssd
+ self.is_v3 = model_config.is_v3
+ self.is_vega = model_config.is_vega
+ self.is_pixart = model_config.is_pixart
+ self.is_auraflow = model_config.is_auraflow
+ self.is_flux = model_config.is_flux
+
+ self.use_text_encoder_1 = model_config.use_text_encoder_1
+ self.use_text_encoder_2 = model_config.use_text_encoder_2
+
+ self.config_file = None
+
+ self.is_flow_matching = False
+ if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler):
+ self.is_flow_matching = True
+
+ self.quantize_device = quantize_device if quantize_device is not None else self.device
+ self.low_vram = self.model_config.low_vram
+
+ # merge in and preview active with -1 weight
+ self.invert_assistant_lora = False
+
+ def load_model(self):
+ if self.is_loaded:
+ return
+ dtype = get_torch_dtype(self.dtype)
+
+ # move the betas, alphas and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
+ # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
+ # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
+ # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)
+
+ model_path = self.model_config.name_or_path
+ if 'civitai.com' in self.model_config.name_or_path:
+ # this is a civitai model url, use the civitai loader.
+ from toolkit.civitai import get_model_path_from_url
+ model_path = get_model_path_from_url(self.model_config.name_or_path)
+
+ load_args = {}
+ if self.noise_scheduler:
+ load_args['scheduler'] = self.noise_scheduler
+
+ if self.model_config.vae_path is not None:
+ load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
+ if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
+ if self.custom_pipeline is not None:
+ pipln = self.custom_pipeline
+ else:
+ pipln = StableDiffusionXLPipeline
+ # pipln = StableDiffusionKDiffusionXLPipeline
+
+ # see if path exists
+ if not os.path.exists(model_path) or os.path.isdir(model_path):
+ # try to load with default diffusers
+ pipe = pipln.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device=self.device_torch,
+ # variant="fp16",
+ use_safetensors=True,
+ **load_args
+ )
+ else:
+ pipe = pipln.from_single_file(
+ model_path,
+ device=self.device_torch,
+ torch_dtype=self.torch_dtype,
+ )
+
+ if 'vae' in load_args and load_args['vae'] is not None:
+ pipe.vae = load_args['vae']
+ flush()
+
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+ tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
+ for text_encoder in text_encoders:
+ text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ text_encoder = text_encoders
+
+ pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+
+ if self.model_config.experimental_xl:
+ print("Experimental XL mode enabled")
+ print("Loading and injecting alt weights")
+ # load the mismatched weight and force it in
+ raw_state_dict = load_file(model_path)
+ replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone()
+ del raw_state_dict
+ # get state dict for the 2nd text encoder
+ te1_state_dict = text_encoders[1].state_dict()
+ # replace weight with mismatched weight
+ te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
+ flush()
+ print("Injecting alt weights")
+ elif self.model_config.is_v3:
+ if self.custom_pipeline is not None:
+ pipln = self.custom_pipeline
+ else:
+ pipln = StableDiffusion3Pipeline
+
+ print("Loading SD3 model")
+ # assume it is the large model
+ base_model_path = "stabilityai/stable-diffusion-3.5-large"
+ print("Loading transformer")
+ subfolder = 'transformer'
+ transformer_path = model_path
+ # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
+ if os.path.exists(transformer_path):
+ subfolder = None
+ transformer_path = os.path.join(transformer_path, 'transformer')
+ # check if the path is a full checkpoint.
+ te_folder_path = os.path.join(model_path, 'text_encoder')
+ # if we have the te, this folder is a full checkpoint, use it as the base
+ if os.path.exists(te_folder_path):
+ base_model_path = model_path
+ else:
+ # is remote use whatever path we were given
+ base_model_path = model_path
+
+ transformer = SD3Transformer2DModel.from_pretrained(
+ transformer_path,
+ subfolder=subfolder,
+ torch_dtype=dtype,
+ )
+ if not self.low_vram:
+ # for low vram, we leave it on the cpu. Quantizing is slower there, but it allows training on the primary gpu
+ transformer.to(torch.device(self.quantize_device), dtype=dtype)
+ flush()
+
+ if self.model_config.lora_path is not None:
+ raise ValueError("LoRA is not supported for SD3 models currently")
+
+ if self.model_config.quantize:
+ quantization_type = qfloat8
+ print("Quantizing transformer")
+ quantize(transformer, weights=quantization_type)
+ freeze(transformer)
+ transformer.to(self.device_torch)
+ else:
+ transformer.to(self.device_torch, dtype=dtype)
+
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
+ print("Loading vae")
+ vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
+ flush()
+
+ print("Loading t5")
+ tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype)
+ text_encoder_3 = T5EncoderModel.from_pretrained(
+ base_model_path,
+ subfolder="text_encoder_3",
+ torch_dtype=dtype
+ )
+
+ text_encoder_3.to(self.device_torch, dtype=dtype)
+ flush()
+
+ if self.model_config.quantize:
+ print("Quantizing T5")
+ quantize(text_encoder_3, weights=qfloat8)
+ freeze(text_encoder_3)
+ flush()
+
+
+ # see if path exists
+ if not os.path.exists(model_path) or os.path.isdir(model_path):
+ try:
+ # try to load with default diffusers
+ pipe = pipln.from_pretrained(
+ base_model_path,
+ dtype=dtype,
+ device=self.device_torch,
+ tokenizer_3=tokenizer_3,
+ text_encoder_3=text_encoder_3,
+ transformer=transformer,
+ # variant="fp16",
+ use_safetensors=True,
+ repo_type="model",
+ ignore_patterns=["*.md", "*.gitattributes"],
+ **load_args
+ )
+ except Exception as e:
+ print(f"Error loading from pretrained: {e}")
+ raise e
+
+ else:
+ pipe = pipln.from_single_file(
+ model_path,
+ transformer=transformer,
+ device=self.device_torch,
+ torch_dtype=self.torch_dtype,
+ tokenizer_3=tokenizer_3,
+ text_encoder_3=text_encoder_3,
+ **load_args
+ )
+
+ flush()
+
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
+ tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
+ # replace the to function with a no-op since it throws an error instead of a warning
+ # text_encoders[2].to = lambda *args, **kwargs: None
+ for text_encoder in text_encoders:
+ text_encoder.to(self.device_torch, dtype=dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ text_encoder = text_encoders
+
+
+ elif self.model_config.is_pixart:
+ te_kwargs = {}
+ # handle quantization of TE
+ te_is_quantized = False
+ if self.model_config.text_encoder_bits == 8:
+ te_kwargs['load_in_8bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+ elif self.model_config.text_encoder_bits == 4:
+ te_kwargs['load_in_4bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+
+ main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
+ if self.model_config.is_pixart_sigma:
+ main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
+
+ main_model_path = model_path
+
+ # load the TE in 8bit mode
+ text_encoder = T5EncoderModel.from_pretrained(
+ main_model_path,
+ subfolder="text_encoder",
+ torch_dtype=self.torch_dtype,
+ **te_kwargs
+ )
+
+ # load the transformer
+ subfolder = "transformer"
+ # check if it is just the unet
+ if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
+ subfolder = None
+
+ if te_is_quantized:
+ # replace the to function with a no-op since it throws an error instead of a warning
+ text_encoder.to = lambda *args, **kwargs: None
+
+ text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+
+ if self.model_config.is_pixart_sigma:
+ # load the transformer only from the save
+ transformer = Transformer2DModel.from_pretrained(
+ model_path if self.model_config.unet_path is None else self.model_config.unet_path,
+ torch_dtype=self.torch_dtype,
+ subfolder='transformer'
+ )
+ pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
+ main_model_path,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ dtype=dtype,
+ device=self.device_torch,
+ **load_args
+ )
+
+ else:
+
+ # load the transformer only from the save
+ transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype,
+ subfolder=subfolder)
+ pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
+ main_model_path,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ dtype=dtype,
+ device=self.device_torch,
+ **load_args
+ ).to(self.device_torch)
+
+ if self.model_config.unet_sample_size is not None:
+ pipe.transformer.config.sample_size = self.model_config.unet_sample_size
+ pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
+
+ flush()
+ # text_encoder = pipe.text_encoder
+ # text_encoder.to(self.device_torch, dtype=dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
+ tokenizer = pipe.tokenizer
+
+ pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+ if self.noise_scheduler is None:
+ self.noise_scheduler = pipe.scheduler
+
+
+ elif self.model_config.is_auraflow:
+ te_kwargs = {}
+ # handle quantization of TE
+ te_is_quantized = False
+ if self.model_config.text_encoder_bits == 8:
+ te_kwargs['load_in_8bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+ elif self.model_config.text_encoder_bits == 4:
+ te_kwargs['load_in_4bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+
+ main_model_path = model_path
+
+ # load the TE in 8bit mode
+ text_encoder = UMT5EncoderModel.from_pretrained(
+ main_model_path,
+ subfolder="text_encoder",
+ torch_dtype=self.torch_dtype,
+ **te_kwargs
+ )
+
+ # load the transformer
+ subfolder = "transformer"
+ # check if it is just the unet
+ if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
+ subfolder = None
+
+ if te_is_quantized:
+ # replace the to function with a no-op since it throws an error instead of a warning
+ text_encoder.to = lambda *args, **kwargs: None
+
+ # load the transformer only from the save
+ transformer = AuraFlowTransformer2DModel.from_pretrained(
+ model_path if self.model_config.unet_path is None else self.model_config.unet_path,
+ torch_dtype=self.torch_dtype,
+ subfolder='transformer'
+ )
+ pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained(
+ main_model_path,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ dtype=dtype,
+ device=self.device_torch,
+ **load_args
+ )
+
+ pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
+
+ # patch auraflow so it can handle other aspect ratios
+ # patch_auraflow_pos_embed(pipe.transformer.pos_embed)
+
+ flush()
+ # text_encoder = pipe.text_encoder
+ # text_encoder.to(self.device_torch, dtype=dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
+ tokenizer = pipe.tokenizer
+
+ elif self.model_config.is_flux:
+ print("Loading Flux model")
+ # base_model_path = "black-forest-labs/FLUX.1-schnell"
+ base_model_path = self.model_config.name_or_path_original
+ print("Loading transformer")
+ subfolder = 'transformer'
+ transformer_path = model_path
+ local_files_only = False
+ # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
+ if os.path.exists(transformer_path):
+ subfolder = None
+ transformer_path = os.path.join(transformer_path, 'transformer')
+ # check if the path is a full checkpoint.
+ te_folder_path = os.path.join(model_path, 'text_encoder')
+ # if we have the te, this folder is a full checkpoint, use it as the base
+ if os.path.exists(te_folder_path):
+ base_model_path = model_path
+
+ transformer = FluxTransformer2DModel.from_pretrained(
+ transformer_path,
+ subfolder=subfolder,
+ torch_dtype=dtype,
+ # low_cpu_mem_usage=False,
+ # device_map=None
+ )
+ if not self.low_vram:
+ # for low vram, we leave it on the cpu. Quantizing is slower there, but it allows training on the primary gpu
+ transformer.to(torch.device(self.quantize_device), dtype=dtype)
+ flush()
+
+ if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
+ if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None:
+ raise ValueError("Cannot load both assistant lora and inference lora at the same time")
+
+ if self.model_config.lora_path:
+ raise ValueError("Cannot load both assistant lora and lora at the same time")
+
+ if not self.is_flux:
+ raise ValueError("Assistant/ inference lora is only supported for flux models currently")
+
+ load_lora_path = self.model_config.inference_lora_path
+ if load_lora_path is None:
+ load_lora_path = self.model_config.assistant_lora_path
+
+ if os.path.isdir(load_lora_path):
+ load_lora_path = os.path.join(
+ load_lora_path, "pytorch_lora_weights.safetensors"
+ )
+ elif not os.path.exists(load_lora_path):
+ print(f"Grabbing lora from the hub: {load_lora_path}")
+ new_lora_path = hf_hub_download(
+ load_lora_path,
+ filename="pytorch_lora_weights.safetensors"
+ )
+ # replace the path
+ load_lora_path = new_lora_path
+
+ if self.model_config.inference_lora_path is not None:
+ self.model_config.inference_lora_path = new_lora_path
+ if self.model_config.assistant_lora_path is not None:
+ self.model_config.assistant_lora_path = new_lora_path
+
+ if self.model_config.assistant_lora_path is not None:
+ # for flux, we assume it is flux schnell. We cannot merge the assistant lora into quantized
+ # weights and unmerge it later, so it would have to run unmerged (slow). Since schnell samples
+ # in just 4 steps, it is better to merge it in now and sample more slowly later; otherwise
+ # training speed is cut roughly in half. So we merge in now and sample with a -1 weight later.
+ self.invert_assistant_lora = True
+ # trigger it to get merged in
+ self.model_config.lora_path = self.model_config.assistant_lora_path
+
+ if self.model_config.lora_path is not None:
+ print("Fusing in LoRA")
+ # need the pipe for peft
+ pipe: FluxPipeline = FluxPipeline(
+ scheduler=None,
+ text_encoder=None,
+ tokenizer=None,
+ text_encoder_2=None,
+ tokenizer_2=None,
+ vae=None,
+ transformer=transformer,
+ )
+ if self.low_vram:
+ # we cannot fuse the loras all at once in low vram mode without running out of memory, so we do it in parts
+ # we could do it on the cpu, but that takes about 5-10 mins vs seconds on the gpu
+ # so we split the work across the double and single transformer block groups, one group at a time
+
+ lora_state_dict = load_file(self.model_config.lora_path)
+ single_transformer_lora = {}
+ single_block_key = "transformer.single_transformer_blocks."
+ double_transformer_lora = {}
+ double_block_key = "transformer.transformer_blocks."
+ for key, value in lora_state_dict.items():
+ if single_block_key in key:
+ single_transformer_lora[key] = value
+ elif double_block_key in key:
+ double_transformer_lora[key] = value
+ else:
+ raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode")
+
+ # double blocks
+ transformer.transformer_blocks = transformer.transformer_blocks.to(
+ torch.device(self.quantize_device), dtype=dtype
+ )
+ pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double")
+ pipe.fuse_lora()
+ pipe.unload_lora_weights()
+ transformer.transformer_blocks = transformer.transformer_blocks.to(
+ 'cpu', dtype=dtype
+ )
+
+ # single blocks
+ transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
+ torch.device(self.quantize_device), dtype=dtype
+ )
+ pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single")
+ pipe.fuse_lora()
+ pipe.unload_lora_weights()
+ transformer.single_transformer_blocks = transformer.single_transformer_blocks.to(
+ 'cpu', dtype=dtype
+ )
+
+ # cleanup
+ del single_transformer_lora
+ del double_transformer_lora
+ del lora_state_dict
+ flush()
+
+ else:
+ # need the pipe to do this unfortunately for now
+ # we have to fuse in the weights before quantizing
+ pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+ pipe.fuse_lora()
+ # unfortunately, not an easier way with peft
+ pipe.unload_lora_weights()
+ flush()
+
+ if self.model_config.quantize:
+ # patch the state dict method
+ patch_dequantization_on_save(transformer)
+ quantization_type = qfloat8
+ print("Quantizing transformer")
+ quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
+ freeze(transformer)
+ transformer.to(self.device_torch)
+ else:
+ transformer.to(self.device_torch, dtype=dtype)
+
+ flush()
+
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
+ print("Loading vae")
+ vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
+ flush()
+
+ print("Loading t5")
+ tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+ text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
+ torch_dtype=dtype)
+
+ text_encoder_2.to(self.device_torch, dtype=dtype)
+ flush()
+
+ if self.model_config.quantize:
+ print("Quantizing T5")
+ quantize(text_encoder_2, weights=qfloat8)
+ freeze(text_encoder_2)
+ flush()
+
+ print("Loading clip")
+ text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
+ tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
+ text_encoder.to(self.device_torch, dtype=dtype)
+
+ print("making pipe")
+ pipe: FluxPipeline = FluxPipeline(
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=None,
+ tokenizer_2=tokenizer_2,
+ vae=vae,
+ transformer=None,
+ )
+ pipe.text_encoder_2 = text_encoder_2
+ pipe.transformer = transformer
+
+ print("preparing")
+
+ text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
+ tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
+
+ pipe.transformer = pipe.transformer.to(self.device_torch)
+
+ flush()
+ text_encoder[0].to(self.device_torch)
+ text_encoder[0].requires_grad_(False)
+ text_encoder[0].eval()
+ text_encoder[1].to(self.device_torch)
+ text_encoder[1].requires_grad_(False)
+ text_encoder[1].eval()
+ pipe.transformer = pipe.transformer.to(self.device_torch)
+ flush()
+ else:
+ if self.custom_pipeline is not None:
+ pipln = self.custom_pipeline
+ else:
+ pipln = StableDiffusionPipeline
+
+ if self.model_config.text_encoder_bits < 16:
+ # this is only supported for T5 models for now
+ te_kwargs = {}
+ # handle quantization of TE
+ te_is_quantized = False
+ if self.model_config.text_encoder_bits == 8:
+ te_kwargs['load_in_8bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+ elif self.model_config.text_encoder_bits == 4:
+ te_kwargs['load_in_4bit'] = True
+ te_kwargs['device_map'] = "auto"
+ te_is_quantized = True
+
+ text_encoder = T5EncoderModel.from_pretrained(
+ model_path,
+ subfolder="text_encoder",
+ torch_dtype=self.te_torch_dtype,
+ **te_kwargs
+ )
+ # replace the to function with a no-op since it throws an error instead of a warning
+ text_encoder.to = lambda *args, **kwargs: None
+
+ load_args['text_encoder'] = text_encoder
+
+ # see if path exists
+ if not os.path.exists(model_path) or os.path.isdir(model_path):
+ # try to load with default diffusers
+ pipe = pipln.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device=self.device_torch,
+ load_safety_checker=False,
+ requires_safety_checker=False,
+ safety_checker=None,
+ # variant="fp16",
+ trust_remote_code=True,
+ **load_args
+ )
+ else:
+ pipe = pipln.from_single_file(
+ model_path,
+ dtype=dtype,
+ device=self.device_torch,
+ load_safety_checker=False,
+ requires_safety_checker=False,
+ torch_dtype=self.torch_dtype,
+ safety_checker=None,
+ trust_remote_code=True,
+ **load_args
+ )
+ flush()
+
+ pipe.register_to_config(requires_safety_checker=False)
+ text_encoder = pipe.text_encoder
+ text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ tokenizer = pipe.tokenizer
+
+ # scheduler doesn't get set sometimes, so we set it here
+ pipe.scheduler = self.noise_scheduler
+
+ # add hacks to unet to help training
+ # pipe.unet = prepare_unet_for_training(pipe.unet)
+
+ if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
+ # pixart and sd3 dont use a unet
+ self.unet = pipe.transformer
+ else:
+ self.unet: 'UNet2DConditionModel' = pipe.unet
+ self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
+ self.vae.eval()
+ self.vae.requires_grad_(False)
+ VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+ self.vae_scale_factor = VAE_SCALE_FACTOR
+ self.unet.to(self.device_torch, dtype=dtype)
+ self.unet.requires_grad_(False)
+ self.unet.eval()
+
+ # load any loras we have
+ if self.model_config.lora_path is not None and not self.is_flux:
+ pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+ pipe.fuse_lora()
+ # unfortunately, not an easier way with peft
+ pipe.unload_lora_weights()
+
+ self.tokenizer = tokenizer
+ self.text_encoder = text_encoder
+ self.pipeline = pipe
+ self.load_refiner()
+ self.is_loaded = True
+
+ if self.model_config.assistant_lora_path is not None:
+ print("Loading assistant lora")
+ self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+ self.model_config.assistant_lora_path, self)
+
+ if self.invert_assistant_lora:
+ # invert and disable during training
+ self.assistant_lora.multiplier = -1.0
+ self.assistant_lora.is_active = False
+
+ if self.model_config.inference_lora_path is not None:
+ print("Loading inference lora")
+ self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+ self.model_config.inference_lora_path, self)
+ # disable during training
+ self.assistant_lora.is_active = False
+
+ if self.is_pixart and self.vae_scale_factor == 16:
+ # TODO make our own pipeline?
+ # we generate an image 2x larger, so we need to copy the sizes from larger ones down
+ # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
+ for key in ASPECT_RATIO_256_BIN.keys():
+ ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2]
+ for key in ASPECT_RATIO_512_BIN.keys():
+ ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2]
+ for key in ASPECT_RATIO_1024_BIN.keys():
+ ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2]
+ for key in ASPECT_RATIO_2048_BIN.keys():
+ ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2]
+
+ def te_train(self):
+ if isinstance(self.text_encoder, list):
+ for te in self.text_encoder:
+ te.train()
+ else:
+ self.text_encoder.train()
+
+ def te_eval(self):
+ if isinstance(self.text_encoder, list):
+ for te in self.text_encoder:
+ te.eval()
+ else:
+ self.text_encoder.eval()
+
+ def load_refiner(self):
+ # for now, we are just going to rely on the TE from the base model
+ # which is TE2 for SDXL and TE for SD (no refiner currently)
+ # and completely ignore a TE that may or may not be packaged with the refiner
+ if self.model_config.refiner_name_or_path is not None:
+ refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
+ # load the refiner model
+ dtype = get_torch_dtype(self.dtype)
+ model_path = self.model_config.refiner_name_or_path
+ if not os.path.exists(model_path) or os.path.isdir(model_path):
+ # TODO only load unet??
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device=self.device_torch,
+ # variant="fp16",
+ use_safetensors=True,
+ ).to(self.device_torch)
+ else:
+ refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
+ model_path,
+ dtype=dtype,
+ device=self.device_torch,
+ torch_dtype=self.torch_dtype,
+ original_config_file=refiner_config_path,
+ ).to(self.device_torch)
+
+ self.refiner_unet = refiner.unet
+ del refiner
+ flush()
+
+ @torch.no_grad()
+ def generate_images(
+ self,
+ image_configs: List[GenerateImageConfig],
+ sampler=None,
+ pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
+ ):
+ merge_multiplier = 1.0
+ flush()
+ # if using assistant, unfuse it
+ if self.model_config.assistant_lora_path is not None:
+ print("Unloading assistant lora")
+ if self.invert_assistant_lora:
+ self.assistant_lora.is_active = True
+ # move weights on to the device
+ self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
+ else:
+ self.assistant_lora.is_active = False
+
+ if self.model_config.inference_lora_path is not None:
+ print("Loading inference lora")
+ self.assistant_lora.is_active = True
+ # move weights on to the device
+ self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
+
+ if self.network is not None:
+ self.network.eval()
+ network = self.network
+ # check if we have the same network weight for all samples. If we do, we can merge in
+ # the network to drastically speed up inference
+ unique_network_weights = set([x.network_multiplier for x in image_configs])
+ if len(unique_network_weights) == 1 and self.network.can_merge_in:
+ can_merge_in = True
+ merge_multiplier = unique_network_weights.pop()
+ network.merge_in(merge_weight=merge_multiplier)
+ else:
+ network = BlankNetwork()
+
+ self.save_device_state()
+ self.set_device_state_preset('generate')
+
+ # save current seed state for training
+ rng_state = torch.get_rng_state()
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+
+ if pipeline is None:
+ noise_scheduler = self.noise_scheduler
+ if sampler is not None:
+ if sampler.startswith("sample_"): # sample_dpmpp_2m
+ # using ksampler
+ noise_scheduler = get_sampler(
+ 'lms', {
+ "prediction_type": self.prediction_type,
+ })
+ else:
+ noise_scheduler = get_sampler(
+ sampler,
+ {
+ "prediction_type": self.prediction_type,
+ },
+ 'sd' if not self.is_pixart else 'pixart'
+ )
+
+ try:
+ noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
+ except:
+ pass
+
+ if sampler is not None and sampler.startswith("sample_") and self.is_xl:
+ # using kdiffusion
+ Pipe = StableDiffusionKDiffusionXLPipeline
+ elif self.is_xl:
+ Pipe = StableDiffusionXLPipeline
+ elif self.is_v3:
+ Pipe = StableDiffusion3Pipeline
+ else:
+ Pipe = StableDiffusionPipeline
+
+ extra_args = {}
+ if self.adapter is not None:
+ if isinstance(self.adapter, T2IAdapter):
+ if self.is_xl:
+ Pipe = StableDiffusionXLAdapterPipeline
+ else:
+ Pipe = StableDiffusionAdapterPipeline
+ extra_args['adapter'] = self.adapter
+ elif isinstance(self.adapter, ControlNetModel):
+ if self.is_xl:
+ Pipe = StableDiffusionXLControlNetPipeline
+ else:
+ Pipe = StableDiffusionControlNetPipeline
+ extra_args['controlnet'] = self.adapter
+ elif isinstance(self.adapter, ReferenceAdapter):
+ # pass the noise scheduler to the adapter
+ self.adapter.noise_scheduler = noise_scheduler
+ else:
+ if self.is_xl:
+ extra_args['add_watermarker'] = False
+
+ # TODO add clip skip
+ if self.is_xl:
+ pipeline = Pipe(
+ vae=self.vae,
+ unet=self.unet,
+ text_encoder=self.text_encoder[0],
+ text_encoder_2=self.text_encoder[1],
+ tokenizer=self.tokenizer[0],
+ tokenizer_2=self.tokenizer[1],
+ scheduler=noise_scheduler,
+ **extra_args
+ ).to(self.device_torch)
+ pipeline.watermark = None
+ elif self.is_flux:
+ if self.model_config.use_flux_cfg:
+ pipeline = FluxWithCFGPipeline(
+ vae=self.vae,
+ transformer=self.unet,
+ text_encoder=self.text_encoder[0],
+ text_encoder_2=self.text_encoder[1],
+ tokenizer=self.tokenizer[0],
+ tokenizer_2=self.tokenizer[1],
+ scheduler=noise_scheduler,
+ **extra_args
+ )
+
+ else:
+ pipeline = FluxPipeline(
+ vae=self.vae,
+ transformer=self.unet,
+ text_encoder=self.text_encoder[0],
+ text_encoder_2=self.text_encoder[1],
+ tokenizer=self.tokenizer[0],
+ tokenizer_2=self.tokenizer[1],
+ scheduler=noise_scheduler,
+ **extra_args
+ )
+ pipeline.watermark = None
+ elif self.is_v3:
+ pipeline = Pipe(
+ vae=self.vae,
+ transformer=self.unet,
+ text_encoder=self.text_encoder[0],
+ text_encoder_2=self.text_encoder[1],
+ text_encoder_3=self.text_encoder[2],
+ tokenizer=self.tokenizer[0],
+ tokenizer_2=self.tokenizer[1],
+ tokenizer_3=self.tokenizer[2],
+ scheduler=noise_scheduler,
+ **extra_args
+ )
+ elif self.is_pixart:
+ pipeline = PixArtSigmaPipeline(
+ vae=self.vae,
+ transformer=self.unet,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ scheduler=noise_scheduler,
+ **extra_args
+ )
+
+ elif self.is_auraflow:
+ pipeline = AuraFlowPipeline(
+ vae=self.vae,
+ transformer=self.unet,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ scheduler=noise_scheduler,
+ **extra_args
+ )
+
+ else:
+ pipeline = Pipe(
+ vae=self.vae,
+ unet=self.unet,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ scheduler=noise_scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=False,
+ **extra_args
+ )
+ flush()
+ # disable progress bar
+ pipeline.set_progress_bar_config(disable=True)
+
+ if sampler is not None and sampler.startswith("sample_"):
+ pipeline.set_scheduler(sampler)
+
+ refiner_pipeline = None
+ if self.refiner_unet:
+ # build refiner pipeline
+ refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
+ vae=pipeline.vae,
+ unet=self.refiner_unet,
+ text_encoder=None,
+ text_encoder_2=pipeline.text_encoder_2,
+ tokenizer=None,
+ tokenizer_2=pipeline.tokenizer_2,
+ scheduler=pipeline.scheduler,
+ add_watermarker=False,
+ requires_aesthetics_score=True,
+ ).to(self.device_torch)
+ # refiner_pipeline.register_to_config(requires_aesthetics_score=False)
+ refiner_pipeline.watermark = None
+ refiner_pipeline.set_progress_bar_config(disable=True)
+ flush()
+
+ start_multiplier = 1.0
+ if self.network is not None:
+ start_multiplier = self.network.multiplier
+
+ # pipeline.to(self.device_torch)
+
+ with network:
+ with torch.no_grad():
+ if self.network is not None:
+ assert self.network.is_active
+
+ for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
+ gen_config = image_configs[i]
+
+ extra = {}
+ validation_image = None
+ if self.adapter is not None and gen_config.adapter_image_path is not None:
+ validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
+ if isinstance(self.adapter, T2IAdapter):
+ # not sure why this is double??
+ validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
+ extra['image'] = validation_image
+ extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
+ if isinstance(self.adapter, ControlNetModel):
+ validation_image = validation_image.resize((gen_config.width, gen_config.height))
+ extra['image'] = validation_image
+ extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
+ if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ validation_image = transform(validation_image)
+ if isinstance(self.adapter, CustomAdapter):
+ # todo allow loading multiple
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+ validation_image = transform(validation_image)
+ self.adapter.num_images = 1
+ if isinstance(self.adapter, ReferenceAdapter):
+ # need -1 to 1
+ validation_image = transforms.ToTensor()(validation_image)
+ validation_image = validation_image * 2.0 - 1.0
+ validation_image = validation_image.unsqueeze(0)
+ self.adapter.set_reference_images(validation_image)
+
+ if self.network is not None:
+ self.network.multiplier = gen_config.network_multiplier
+ torch.manual_seed(gen_config.seed)
+ torch.cuda.manual_seed(gen_config.seed)
+
+ generator = torch.manual_seed(gen_config.seed)
+
+ if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
+ and gen_config.adapter_image_path is not None:
+ # run through the adapter to saturate the embeds
+ conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
+ self.adapter(conditional_clip_embeds)
+
+ if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+ # handle condition the prompts
+ gen_config.prompt = self.adapter.condition_prompt(
+ gen_config.prompt,
+ is_unconditional=False,
+ )
+ gen_config.prompt_2 = gen_config.prompt
+ gen_config.negative_prompt = self.adapter.condition_prompt(
+ gen_config.negative_prompt,
+ is_unconditional=True,
+ )
+ gen_config.negative_prompt_2 = gen_config.negative_prompt
+
+ if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
+ self.adapter.trigger_pre_te(
+ tensors_0_1=validation_image,
+ is_training=False,
+ has_been_preprocessed=False,
+ quad_count=4
+ )
+
+ # encode the prompt ourselves so we can do fun stuff with embeddings
+ if isinstance(self.adapter, CustomAdapter):
+ self.adapter.is_unconditional_run = False
+ conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
+
+ if isinstance(self.adapter, CustomAdapter):
+ self.adapter.is_unconditional_run = True
+ unconditional_embeds = self.encode_prompt(
+ gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
+ )
+ if isinstance(self.adapter, CustomAdapter):
+ self.adapter.is_unconditional_run = False
+
+ # allow any manipulations to take place to embeddings
+ gen_config.post_process_embeddings(
+ conditional_embeds,
+ unconditional_embeds,
+ )
+
+ if self.decorator is not None:
+ # apply the decorator to the embeddings
+ conditional_embeds.text_embeds = self.decorator(conditional_embeds.text_embeds)
+ unconditional_embeds.text_embeds = self.decorator(unconditional_embeds.text_embeds, is_unconditional=True)
+
+ if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
+ and gen_config.adapter_image_path is not None:
+ # apply the image projection
+ conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
+ unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
+ True)
+ conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
+ unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
+
+ if self.adapter is not None and isinstance(self.adapter,
+ CustomAdapter) and validation_image is not None:
+ conditional_embeds = self.adapter.condition_encoded_embeds(
+ tensors_0_1=validation_image,
+ prompt_embeds=conditional_embeds,
+ is_training=False,
+ has_been_preprocessed=False,
+ is_generating_samples=True,
+ )
+ unconditional_embeds = self.adapter.condition_encoded_embeds(
+ tensors_0_1=validation_image,
+ prompt_embeds=unconditional_embeds,
+ is_training=False,
+ has_been_preprocessed=False,
+ is_unconditional=True,
+ is_generating_samples=True,
+ )
+
+ if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
+ gen_config.extra_values) > 0:
+ extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
+ dtype=self.torch_dtype)
+ # apply extra values to the embeddings
+ self.adapter.add_extra_values(extra_values, is_unconditional=False)
+ self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
+ pass # todo remove, for debugging
+
+ if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
+ # if we have a refiner loaded, set the denoising end at the refiner start
+ extra['denoising_end'] = gen_config.refiner_start_at
+ extra['output_type'] = 'latent'
+ if not self.is_xl:
+ raise ValueError("Refiner is only supported for XL models")
+
+ conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+ unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
+
+ if self.is_xl:
+ # fix guidance rescale for sdxl
+ # was trained on 0.7 (I believe)
+
+ grs = gen_config.guidance_rescale
+ # if grs is None or grs < 0.00001:
+ # grs = 0.7
+ # grs = 0.0
+
+ if sampler is not None and sampler.startswith("sample_"):
+ extra['use_karras_sigmas'] = True
+ extra = {
+ **extra,
+ **gen_config.extra_kwargs,
+ }
+
+ img = pipeline(
+ # prompt=gen_config.prompt,
+ # prompt_2=gen_config.prompt_2,
+ prompt_embeds=conditional_embeds.text_embeds,
+ pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+ negative_prompt_embeds=unconditional_embeds.text_embeds,
+ negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+ # negative_prompt=gen_config.negative_prompt,
+ # negative_prompt_2=gen_config.negative_prompt_2,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ guidance_rescale=grs,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ elif self.is_v3:
+ img = pipeline(
+ prompt_embeds=conditional_embeds.text_embeds,
+ pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+ negative_prompt_embeds=unconditional_embeds.text_embeds,
+ negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ elif self.is_flux:
+ if self.model_config.use_flux_cfg:
+ img = pipeline(
+ prompt_embeds=conditional_embeds.text_embeds,
+ pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+ negative_prompt_embeds=unconditional_embeds.text_embeds,
+ negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ else:
+ img = pipeline(
+ prompt_embeds=conditional_embeds.text_embeds,
+ pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+ # negative_prompt_embeds=unconditional_embeds.text_embeds,
+ # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ elif self.is_pixart:
+ # needs attention masks for some reason
+ img = pipeline(
+ prompt=None,
+ prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
+ prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+ dtype=self.unet.dtype),
+ negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+ dtype=self.unet.dtype),
+ negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+ dtype=self.unet.dtype),
+ negative_prompt=None,
+ # negative_prompt=gen_config.negative_prompt,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ elif self.is_auraflow:
+ pipeline: AuraFlowPipeline = pipeline
+
+ img = pipeline(
+ prompt=None,
+ prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
+ prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+ dtype=self.unet.dtype),
+ negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+ dtype=self.unet.dtype),
+ negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+ dtype=self.unet.dtype),
+ negative_prompt=None,
+ # negative_prompt=gen_config.negative_prompt,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+ else:
+ img = pipeline(
+ # prompt=gen_config.prompt,
+ prompt_embeds=conditional_embeds.text_embeds,
+ negative_prompt_embeds=unconditional_embeds.text_embeds,
+ # negative_prompt=gen_config.negative_prompt,
+ height=gen_config.height,
+ width=gen_config.width,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ latents=gen_config.latents,
+ generator=generator,
+ **extra
+ ).images[0]
+
+ if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
+ # slice off just the last 1280 on the last dim, as the refiner does not use the first text encoder
+ # todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ
+ refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
+ refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
+ # run through refiner
+ img = refiner_pipeline(
+ # prompt=gen_config.prompt,
+ # prompt_2=gen_config.prompt_2,
+
+ # slice these as it does not use both text encoders
+ # height=gen_config.height,
+ # width=gen_config.width,
+ prompt_embeds=refiner_text_embeds,
+ pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+ negative_prompt_embeds=refiner_unconditional_text_embeds,
+ negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+ num_inference_steps=gen_config.num_inference_steps,
+ guidance_scale=gen_config.guidance_scale,
+ guidance_rescale=grs,
+ denoising_start=gen_config.refiner_start_at,
+ denoising_end=gen_config.num_inference_steps,
+ image=img.unsqueeze(0),
+ generator=generator,
+ ).images[0]
+
+ gen_config.save_image(img, i)
+ gen_config.log_image(img, i)
+ flush()
+
+ if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
+ self.adapter.clear_memory()
+
+ # clear pipeline and cache to reduce vram usage
+ del pipeline
+ if refiner_pipeline is not None:
+ del refiner_pipeline
+ torch.cuda.empty_cache()
+
+ # restore training state
+ torch.set_rng_state(rng_state)
+ if cuda_rng_state is not None:
+ torch.cuda.set_rng_state(cuda_rng_state)
+
+ self.restore_device_state()
+ if self.network is not None:
+ self.network.train()
+ self.network.multiplier = start_multiplier
+
+ self.unet.to(self.device_torch, dtype=self.torch_dtype)
+ if network.is_merged_in:
+ network.merge_out(merge_multiplier)
+ # self.tokenizer.to(original_device_dict['tokenizer'])
+
+ # restore assistant / inference lora state now that sampling is done
+ if self.model_config.assistant_lora_path is not None:
+ print("Loading assistant lora")
+ if self.invert_assistant_lora:
+ self.assistant_lora.is_active = False
+ # move weights off the device
+ self.assistant_lora.force_to('cpu', self.torch_dtype)
+ else:
+ self.assistant_lora.is_active = True
+
+ if self.model_config.inference_lora_path is not None:
+ print("Unloading inference lora")
+ self.assistant_lora.is_active = False
+ # move weights off the device
+ self.assistant_lora.force_to('cpu', self.torch_dtype)
+
+ flush()
+
+ def get_latent_noise(
+ self,
+ height=None,
+ width=None,
+ pixel_height=None,
+ pixel_width=None,
+ batch_size=1,
+ noise_offset=0.0,
+ ):
+ VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+ if height is None and pixel_height is None:
+ raise ValueError("height or pixel_height must be specified")
+ if width is None and pixel_width is None:
+ raise ValueError("width or pixel_width must be specified")
+ if height is None:
+ height = pixel_height // VAE_SCALE_FACTOR
+ if width is None:
+ width = pixel_width // VAE_SCALE_FACTOR
+
+ num_channels = self.unet.config['in_channels']
+ if self.is_flux:
+ # the flux config reports 64 input channels (2x2-packed latents); the unpacked latent has 16 channels
+ num_channels = 16
+ noise = torch.randn(
+ (
+ batch_size,
+ num_channels,
+ height,
+ width,
+ ),
+ device=self.unet.device,
+ )
+ noise = apply_noise_offset(noise, noise_offset)
+ return noise
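+ # Sizing sketch (editor's note, not from the original code): with the usual SD/SDXL
+ # VAE (four block_out_channels -> VAE_SCALE_FACTOR = 8), a 1024x1024 pixel image maps
+ # to a 128x128 latent grid, so get_latent_noise(pixel_height=1024, pixel_width=1024)
+ # returns noise of shape (batch_size, num_channels, 128, 128).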
+
+ def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
+ VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+ if self.is_xl:
+ bs, ch, h, w = list(latents.shape)
+
+ height = h * VAE_SCALE_FACTOR
+ width = w * VAE_SCALE_FACTOR
+
+ dtype = latents.dtype
+ # just do it without any cropping nonsense
+ target_size = (height, width)
+ original_size = (height, width)
+ crops_coords_top_left = (0, 0)
+ if requires_aesthetic_score:
+ # refiner
+ # https://huggingface.co/papers/2307.01952
+ aesthetic_score = 6.0 # simulate one
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(latents.device, dtype=dtype)
+
+ batch_time_ids = torch.cat(
+ [add_time_ids for _ in range(bs)]
+ )
+ return batch_time_ids
+ else:
+ return None
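+ # Worked example (editor's sketch): for a 128x128 SDXL latent with an 8x VAE scale
+ # factor, add_time_ids becomes [1024, 1024, 0, 0, 1024, 1024] (original size + crop
+ # coords + target size), or [1024, 1024, 0, 0, 6.0] when requires_aesthetic_score is
+ # set for the refiner, repeated across the batch dimension.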
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
+ noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
+ timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
+
+ if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
+ timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
+
+ noisy_latents_chunks = []
+
+ for idx in range(original_samples.shape[0]):
+ noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
+ timesteps_chunks[idx])
+ noisy_latents_chunks.append(noisy_latents)
+
+ noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
+ return noisy_latents
+
+ def predict_noise(
+ self,
+ latents: torch.Tensor,
+ text_embeddings: Union[PromptEmbeds, None] = None,
+ timestep: Union[int, torch.Tensor] = 1,
+ guidance_scale=7.5,
+ guidance_rescale=0,
+ add_time_ids=None,
+ conditional_embeddings: Union[PromptEmbeds, None] = None,
+ unconditional_embeddings: Union[PromptEmbeds, None] = None,
+ is_input_scaled=False,
+ detach_unconditional=False,
+ rescale_cfg=None,
+ return_conditional_pred=False,
+ guidance_embedding_scale=1.0,
+ bypass_guidance_embedding=False,
+ **kwargs,
+ ):
+ conditional_pred = None
+ # get the embeddings
+ if text_embeddings is None and conditional_embeddings is None:
+ raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
+ if text_embeddings is None and unconditional_embeddings is not None:
+ text_embeddings = concat_prompt_embeds([
+ unconditional_embeddings, # negative embedding
+ conditional_embeddings, # positive embedding
+ ])
+ elif text_embeddings is None and conditional_embeddings is not None:
+ # not doing cfg
+ text_embeddings = conditional_embeddings
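+ # (editor's note) layout after concat_prompt_embeds above: the first half of the
+ # batch is unconditional, the second half conditional; the chunk(2) calls later in
+ # this method rely on that order.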
+
+ # CFG compares the negative and positive predictions. If we were handed concatenated
+ # embeddings we are doing CFG; otherwise we are not, which takes half the time.
+ do_classifier_free_guidance = True
+
+ # check if batch size of embeddings matches batch size of latents
+ if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
+ do_classifier_free_guidance = False
+ elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
+ raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
+ latents = latents.to(self.device_torch)
+ text_embeddings = text_embeddings.to(self.device_torch)
+ timestep = timestep.to(self.device_torch)
+
+ # if timestep is zero dim, unsqueeze it
+ if len(timestep.shape) == 0:
+ timestep = timestep.unsqueeze(0)
+
+ # if we only have 1 timestep, we can just use the same timestep for all
+ if timestep.shape[0] == 1 and latents.shape[0] > 1:
+ # check if it is rank 1 or 2
+ if len(timestep.shape) == 1:
+ timestep = timestep.repeat(latents.shape[0])
+ else:
+ # repeat along the batch dim only; repeating dim 1 zero times would produce an empty tensor
+ timestep = timestep.repeat(latents.shape[0], 1)
+
+ # handle t2i adapters
+ if 'down_intrablock_additional_residuals' in kwargs:
+ # go through each item and concat if doing cfg and it doesn't have the same shape
+ for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
+ if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
+ kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
+
+ # handle controlnet
+ if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
+ # go through each item and concat if doing cfg and it doesn't have the same shape
+ for idx, item in enumerate(kwargs['down_block_additional_residuals']):
+ if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
+ kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
+ for idx, item in enumerate(kwargs['mid_block_additional_residual']):
+ if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
+ kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0)
+
+ def scale_model_input(model_input, timestep_tensor):
+ if is_input_scaled:
+ return model_input
+ mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+ timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+ out_chunks = []
+ # scale each sample separately so scheduler state does not leak between them
+ for idx in range(model_input.shape[0]):
+ # if scheduler has step_index
+ if hasattr(self.noise_scheduler, '_step_index'):
+ self.noise_scheduler._step_index = None
+ out_chunks.append(
+ self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx])
+ )
+ return torch.cat(out_chunks, dim=0)
+
+ if self.is_xl:
+ with torch.no_grad():
+ # 16, 6 for bs of 4
+ if add_time_ids is None:
+ add_time_ids = self.get_time_ids_from_latents(latents)
+
+ if do_classifier_free_guidance:
+ # todo: check this with larger batches
+ add_time_ids = torch.cat([add_time_ids] * 2)
+
+ if do_classifier_free_guidance:
+ latent_model_input = torch.cat([latents] * 2)
+ timestep = torch.cat([timestep] * 2)
+ else:
+ latent_model_input = latents
+
+ latent_model_input = scale_model_input(latent_model_input, timestep)
+
+ added_cond_kwargs = {
+ # todo: can we zero out the second text encoder here, or match a blank string?
+ "text_embeds": text_embeddings.pooled_embeds,
+ "time_ids": add_time_ids,
+ }
+
+ if self.model_config.refiner_name_or_path is not None:
+ # the refiner handles the second half of the batch; run both the base unet and the refiner
+ if do_classifier_free_guidance:
+ raise ValueError("Refiner is not supported with classifier free guidance")
+
+ if self.unet.training:
+ input_chunks = torch.chunk(latent_model_input, 2, dim=0)
+ timestep_chunks = torch.chunk(timestep, 2, dim=0)
+ added_cond_kwargs_chunked = {
+ "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
+ "time_ids": torch.chunk(add_time_ids, 2, dim=0),
+ }
+ text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)
+
+ # predict the noise residual
+ base_pred = self.unet(
+ input_chunks[0],
+ timestep_chunks[0],
+ encoder_hidden_states=text_embeds_chunks[0],
+ added_cond_kwargs={
+ "text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
+ "time_ids": added_cond_kwargs_chunked['time_ids'][0],
+ },
+ **kwargs,
+ ).sample
+
+ refiner_pred = self.refiner_unet(
+ input_chunks[1],
+ timestep_chunks[1],
+ encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
+ # only uses the second text encoder (the last 1280 dims)
+ added_cond_kwargs={
+ "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
+ # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
+ "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
+ },
+ **kwargs,
+ ).sample
+
+ noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
+ else:
+ noise_pred = self.refiner_unet(
+ latent_model_input,
+ timestep,
+ encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
+ # only uses the second text encoder (the last 1280 dims)
+ added_cond_kwargs={
+ "text_embeds": text_embeddings.pooled_embeds,
+ "time_ids": self.get_time_ids_from_latents(latent_model_input,
+ requires_aesthetic_score=True),
+ },
+ **kwargs,
+ ).sample
+
+ else:
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input.to(self.device_torch, self.torch_dtype),
+ timestep,
+ encoder_hidden_states=text_embeddings.text_embeds,
+ added_cond_kwargs=added_cond_kwargs,
+ **kwargs,
+ ).sample
+
+ conditional_pred = noise_pred
+
+ if do_classifier_free_guidance:
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ conditional_pred = noise_pred_text
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
+ if guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ else:
+ with torch.no_grad():
+ if do_classifier_free_guidance:
+ # if we are doing classifier free guidance, need to double up
+ latent_model_input = torch.cat([latents] * 2, dim=0)
+ timestep = torch.cat([timestep] * 2)
+ else:
+ latent_model_input = latents
+
+ latent_model_input = scale_model_input(latent_model_input, timestep)
+
+ # check if we need to concat timesteps
+ if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
+ ts_bs = timestep.shape[0]
+ if ts_bs != latent_model_input.shape[0]:
+ if ts_bs == 1:
+ timestep = torch.cat([timestep] * latent_model_input.shape[0])
+ elif ts_bs * 2 == latent_model_input.shape[0]:
+ timestep = torch.cat([timestep] * 2, dim=0)
+ else:
+ raise ValueError(
+ f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}")
+
+ # predict the noise residual
+ if self.is_pixart:
+ VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+ batch_size, ch, h, w = list(latents.shape)
+
+ height = h * VAE_SCALE_FACTOR
+ width = w * VAE_SCALE_FACTOR
+
+ if self.pipeline.transformer.config.sample_size == 256:
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
+ elif self.pipeline.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.pipeline.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ elif self.pipeline.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
+ else:
+ raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
+ orig_height, orig_width = height, width
+ height, width = self.pipeline.image_processor.classify_height_width_bin(height, width,
+ ratios=aspect_ratio_bin)
+
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.unet.config.sample_size == 128 or (
+ self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+ resolution = torch.tensor([height, width]).repeat(batch_size, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
+ resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
+ aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
+
+ if do_classifier_free_guidance:
+ resolution = torch.cat([resolution, resolution], dim=0)
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ noise_pred = self.unet(
+ latent_model_input.to(self.device_torch, self.torch_dtype),
+ encoder_hidden_states=text_embeddings.text_embeds,
+ encoder_attention_mask=text_embeddings.attention_mask,
+ timestep=timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ **kwargs
+ )[0]
+
+ # learned sigma
+ if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+ else:
+ if self.unet.device != self.device_torch:
+ self.unet.to(self.device_torch)
+ if self.unet.dtype != self.torch_dtype:
+ self.unet = self.unet.to(dtype=self.torch_dtype)
+ if self.is_flux:
+ with torch.no_grad():
+
+ bs, c, h, w = latent_model_input.shape
+ latent_model_input_packed = rearrange(
+ latent_model_input,
+ "b c (h ph) (w pw) -> b (h w) (c ph pw)",
+ ph=2,
+ pw=2
+ )
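+ # (editor's note) e.g. a (bs, 16, 128, 128) flux latent packs to (bs, 64*64, 16*2*2)
+ # = (bs, 4096, 64), matching the [1, 4096, 64] shape annotated on the transformer call below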
+
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch)
+
+ txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
+
+ # handle guidance
+ if self.unet.config.guidance_embeds:
+ if isinstance(guidance_embedding_scale, list):
+ guidance = torch.tensor(guidance_embedding_scale, device=self.device_torch)
+ else:
+ guidance = torch.tensor([guidance_embedding_scale], device=self.device_torch)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if bypass_guidance_embedding:
+ bypass_flux_guidance(self.unet)
+
+ cast_dtype = self.unet.dtype
+ # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
+ noise_pred = self.unet(
+ hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64]
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
+ # todo: make sure this doesn't change
+ timestep=timestep / 1000, # timestep is 1000 scale
+ encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),
+ # [1, 512, 4096]
+ pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768]
+ txt_ids=txt_ids, # [1, 512, 3]
+ img_ids=img_ids, # [1, 4096, 3]
+ guidance=guidance,
+ return_dict=False,
+ **kwargs,
+ )[0]
+
+ if isinstance(noise_pred, QTensor):
+ noise_pred = noise_pred.dequantize()
+
+ noise_pred = rearrange(
+ noise_pred,
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+ h=latent_model_input.shape[2] // 2,
+ w=latent_model_input.shape[3] // 2,
+ ph=2,
+ pw=2,
+ c=latent_model_input.shape[1],
+ )
+
+ if bypass_guidance_embedding:
+ restore_flux_guidance(self.unet)
+ elif self.is_v3:
+ noise_pred = self.unet(
+ hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
+ timestep=timestep,
+ encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+ pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
+ **kwargs,
+ ).sample
+ if isinstance(noise_pred, QTensor):
+ noise_pred = noise_pred.dequantize()
+ elif self.is_auraflow:
+ # AuraFlow uses timestep values between 0 and 1, with t=1 as pure noise and t=0 as the image
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0])
+ t = t.to(self.device_torch, self.torch_dtype)
+
+ noise_pred = self.unet(
+ latent_model_input,
+ encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+ timestep=t,
+ return_dict=False,
+ )[0]
+ else:
+ noise_pred = self.unet(
+ latent_model_input.to(self.device_torch, self.torch_dtype),
+ timestep=timestep,
+ encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+ **kwargs,
+ ).sample
+
+ conditional_pred = noise_pred
+
+ if do_classifier_free_guidance:
+ # perform guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
+ conditional_pred = noise_pred_text
+ if detach_unconditional:
+ noise_pred_uncond = noise_pred_uncond.detach()
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+ if rescale_cfg is not None and rescale_cfg != guidance_scale:
+ with torch.no_grad():
+ # do cfg at the target rescale so we can match it
+ target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
+ noise_pred_text - noise_pred_uncond
+ )
+ target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
+ target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
+
+ pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach()
+ pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()
+
+ # match the mean and std
+ noise_pred = (noise_pred - pred_mean) / pred_std
+ noise_pred = (noise_pred * target_std) + target_mean
+
+ # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
+ if guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ if return_conditional_pred:
+ return noise_pred, conditional_pred
+ return noise_pred
+
+ def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
+ if noise_scheduler is None:
+ noise_scheduler = self.noise_scheduler
+ # sometimes they are on the wrong device, no idea why
+ if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler):
+ try:
+ noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
+ noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
+ noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
+ except Exception as e:
+ pass
+
+ mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+ latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
+ timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+ out_chunks = []
+ if len(timestep_chunks) == 1 and len(mi_chunks) > 1:
+ # expand timestep to match
+ timestep_chunks = timestep_chunks * len(mi_chunks)
+
+ for idx in range(model_input.shape[0]):
+ # Reset the scheduler's internal step index so it is recomputed for each chunk
+ if hasattr(noise_scheduler, '_step_index'):
+ noise_scheduler._step_index = None
+ if hasattr(noise_scheduler, 'is_scale_input_called'):
+ noise_scheduler.is_scale_input_called = True
+ out_chunks.append(
+ noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[
+ 0]
+ )
+ return torch.cat(out_chunks, dim=0)
+
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
+ def diffuse_some_steps(
+ self,
+ latents: torch.FloatTensor,
+ text_embeddings: PromptEmbeds,
+ total_timesteps: int = 1000,
+ start_timesteps=0,
+ guidance_scale=1,
+ add_time_ids=None,
+ bleed_ratio: float = 0.5,
+ bleed_latents: torch.FloatTensor = None,
+ is_input_scaled=False,
+ return_first_prediction=False,
+ **kwargs,
+ ):
+ timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]
+
+ first_prediction = None
+
+ for timestep in tqdm(timesteps_to_run, leave=False):
+ timestep = timestep.unsqueeze_(0)
+ noise_pred, conditional_pred = self.predict_noise(
+ latents,
+ text_embeddings,
+ timestep,
+ guidance_scale=guidance_scale,
+ add_time_ids=add_time_ids,
+ is_input_scaled=is_input_scaled,
+ return_conditional_pred=True,
+ **kwargs,
+ )
+ # some schedulers need to run separately, so do that. (euler for example)
+
+ if return_first_prediction and first_prediction is None:
+ first_prediction = conditional_pred
+
+ latents = self.step_scheduler(noise_pred, latents, timestep)
+
+ # if not last step, and bleeding, bleed in some latents
+ if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
+ latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)
+
+ # only skip first scaling
+ is_input_scaled = False
+
+ # return latents_steps
+ if return_first_prediction:
+ return latents, first_prediction
+ return latents
+
+ def encode_prompt(
+ self,
+ prompt,
+ prompt2=None,
+ num_images_per_prompt=1,
+ force_all=False,
+ long_prompts=False,
+ max_length=None,
+ dropout_prob=0.0,
+ ) -> PromptEmbeds:
+ # sd1.5 embeddings are (bs, 77, 768)
+ prompt = prompt
+ # if it is not a list, make it one
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+
+ if prompt2 is not None and not isinstance(prompt2, list):
+ prompt2 = [prompt2]
+ if self.is_xl:
+ # todo make this a config
+ # 50% chance to use an encoder anyway even if it is disabled
+ # allows the other TE to compensate for the disabled one
+ # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5
+ # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5
+ use_encoder_1 = True
+ use_encoder_2 = True
+
+ return PromptEmbeds(
+ train_tools.encode_prompts_xl(
+ self.tokenizer,
+ self.text_encoder,
+ prompt,
+ prompt2,
+ num_images_per_prompt=num_images_per_prompt,
+ use_text_encoder_1=use_encoder_1,
+ use_text_encoder_2=use_encoder_2,
+ truncate=not long_prompts,
+ max_length=max_length,
+ dropout_prob=dropout_prob,
+ )
+ )
+ if self.is_v3:
+ return PromptEmbeds(
+ train_tools.encode_prompts_sd3(
+ self.tokenizer,
+ self.text_encoder,
+ prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ truncate=not long_prompts,
+ max_length=max_length,
+ dropout_prob=dropout_prob,
+ pipeline=self.pipeline,
+ )
+ )
+ elif self.is_pixart:
+ embeds, attention_mask = train_tools.encode_prompts_pixart(
+ self.tokenizer,
+ self.text_encoder,
+ prompt,
+ truncate=not long_prompts,
+ max_length=300 if self.model_config.is_pixart_sigma else 120,
+ dropout_prob=dropout_prob
+ )
+ return PromptEmbeds(
+ embeds,
+ attention_mask=attention_mask,
+ )
+ elif self.is_auraflow:
+ embeds, attention_mask = train_tools.encode_prompts_auraflow(
+ self.tokenizer,
+ self.text_encoder,
+ prompt,
+ truncate=not long_prompts,
+ max_length=256,
+ dropout_prob=dropout_prob
+ )
+ return PromptEmbeds(
+ embeds,
+ attention_mask=attention_mask, # not used
+ )
+ elif self.is_flux:
+ prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
+ self.tokenizer, # list
+ self.text_encoder, # list
+ prompt,
+ truncate=not long_prompts,
+ max_length=512,
+ dropout_prob=dropout_prob,
+ attn_mask=self.model_config.attn_masking
+ )
+ pe = PromptEmbeds(
+ prompt_embeds
+ )
+ pe.pooled_embeds = pooled_prompt_embeds
+ return pe
+
+
+ elif isinstance(self.text_encoder, T5EncoderModel):
+ embeds, attention_mask = train_tools.encode_prompts_pixart(
+ self.tokenizer,
+ self.text_encoder,
+ prompt,
+ truncate=not long_prompts,
+ max_length=256,
+ dropout_prob=dropout_prob
+ )
+
+ # apply the attention mask to the embeddings (zero out padded token positions)
+ prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape)
+ embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device)
+ return PromptEmbeds(
+ embeds,
+
+ # do we want attn mask here?
+ # attention_mask=attention_mask,
+ )
+ else:
+ return PromptEmbeds(
+ train_tools.encode_prompts(
+ self.tokenizer,
+ self.text_encoder,
+ prompt,
+ truncate=not long_prompts,
+ max_length=max_length,
+ dropout_prob=dropout_prob
+ )
+ )
+
+ @torch.no_grad()
+ def encode_images(
+ self,
+ image_list: List[torch.Tensor],
+ device=None,
+ dtype=None
+ ):
+ if device is None:
+ device = self.vae_device_torch
+ if dtype is None:
+ dtype = self.vae_torch_dtype
+
+ latent_list = []
+ # Move the vae to the device if it is on cpu
+ if self.vae.device == 'cpu':
+ self.vae.to(device)
+ self.vae.eval()
+ self.vae.requires_grad_(False)
+ # move to device and dtype
+ image_list = [image.to(device, dtype=dtype) for image in image_list]
+
+ VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+
+ # resize images if not divisible by the VAE scale factor
+ for i in range(len(image_list)):
+ image = image_list[i]
+ if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
+ image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
+ image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
+
+ images = torch.stack(image_list)
+ if isinstance(self.vae, AutoencoderTiny):
+ latents = self.vae.encode(images, return_dict=False)[0]
+ else:
+ latents = self.vae.encode(images).latent_dist.sample()
+ shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+
+ # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
+ # z = self.scale_factor * (z - self.shift_factor)
+ latents = self.vae.config['scaling_factor'] * (latents - shift)
+ latents = latents.to(device, dtype=dtype)
+
+ return latents
+
+ def decode_latents(
+ self,
+ latents: torch.Tensor,
+ device=None,
+ dtype=None
+ ):
+ if device is None:
+ device = self.device
+ if dtype is None:
+ dtype = self.torch_dtype
+
+ # Move the vae to the device if it is on cpu
+ if self.vae.device == 'cpu':
+ self.vae.to(self.device)
+ latents = latents.to(device, dtype=dtype)
+ # mirror encode_images: treat a missing shift_factor as 0
+ shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+ latents = (latents / self.vae.config['scaling_factor']) + shift
+ images = self.vae.decode(latents).sample
+ images = images.to(device, dtype=dtype)
+
+ return images
+
+ def encode_image_prompt_pairs(
+ self,
+ prompt_list: List[str],
+ image_list: List[torch.Tensor],
+ device=None,
+ dtype=None
+ ):
+ # todo check image types and expand and rescale as needed
+ # device and dtype are for outputs
+ if device is None:
+ device = self.device
+ if dtype is None:
+ dtype = self.torch_dtype
+
+ embedding_list = []
+ latent_list = []
+ # embed the prompts
+ for prompt in prompt_list:
+ embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
+ embedding_list.append(embedding)
+
+ return embedding_list, latent_list
+
+ def get_weight_by_name(self, name):
+ # weights begin with te{te_num}_ for text encoder
+ # weights begin with unet_ for the unet
+ if name.startswith('te'):
+ key = name[4:]
+ # text encoder
+ te_num = int(name[2])
+ if isinstance(self.text_encoder, list):
+ return self.text_encoder[te_num].state_dict()[key]
+ else:
+ return self.text_encoder.state_dict()[key]
+ elif name.startswith('unet'):
+ key = name[5:]
+ # unet
+ return self.unet.state_dict()[key]
+
+ raise ValueError(f"Unknown weight name: {name}")
+
+ def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
+ return inject_trigger_into_prompt(
+ prompt,
+ trigger=trigger,
+ to_replace_list=to_replace_list,
+ add_if_not_present=add_if_not_present,
+ )
+
+ def state_dict(self, vae=True, text_encoder=True, unet=True):
+ state_dict = OrderedDict()
+ if vae:
+ for k, v in self.vae.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
+ state_dict[new_key] = v
+ if text_encoder:
+ if isinstance(self.text_encoder, list):
+ for i, encoder in enumerate(self.text_encoder):
+ for k, v in encoder.state_dict().items():
+ new_key = k if k.startswith(
+ f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
+ state_dict[new_key] = v
+ else:
+ for k, v in self.text_encoder.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
+ state_dict[new_key] = v
+ if unet:
+ for k, v in self.unet.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
+ state_dict[new_key] = v
+ return state_dict
+
+ def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
+ OrderedDict[
+ str, Parameter]:
+ named_params: OrderedDict[str, Parameter] = OrderedDict()
+ if vae:
+ for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
+ named_params[name] = param
+ if text_encoder:
+ if isinstance(self.text_encoder, list):
+ for i, encoder in enumerate(self.text_encoder):
+ if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0:
+ # don't add these params
+ continue
+ if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1:
+ # don't add these params
+ continue
+
+ for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"):
+ named_params[name] = param
+ else:
+ for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
+ named_params[name] = param
+ if unet:
+ if self.is_flux:
+ for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
+ named_params[name] = param
+ else:
+ for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+ named_params[name] = param
+
+ if self.model_config.ignore_if_contains is not None:
+ # remove params that contain the ignore_if_contains from named params
+ for key in list(named_params.keys()):
+ if any([s in key for s in self.model_config.ignore_if_contains]):
+ del named_params[key]
+ if self.model_config.only_if_contains is not None:
+ # remove params that do not contain the only_if_contains from named params
+ for key in list(named_params.keys()):
+ if not any([s in key for s in self.model_config.only_if_contains]):
+ del named_params[key]
+
+ if refiner:
+ for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
+ named_params[name] = param
+
+ # convert to state dict keys, just replace . with _ on keys
+ if state_dict_keys:
+ new_named_params = OrderedDict()
+ for k, v in named_params.items():
+ # replace only the first . with an _
+ new_key = k.replace('.', '_', 1)
+ new_named_params[new_key] = v
+ named_params = new_named_params
+
+ return named_params
+
+ def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):
+
+ # load the full refiner since we only train unet
+ if self.model_config.refiner_name_or_path is None:
+ raise ValueError("Refiner must be specified to save it")
+ refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
+ # load the refiner model
+ dtype = get_torch_dtype(self.dtype)
+ model_path = self.model_config._original_refiner_name_or_path
+ if not os.path.exists(model_path) or os.path.isdir(model_path):
+ # TODO only load unet??
+ refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device='cpu',
+ # variant="fp16",
+ use_safetensors=True,
+ )
+ else:
+ refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
+ model_path,
+ dtype=dtype,
+ device='cpu',
+ torch_dtype=self.torch_dtype,
+ original_config_file=refiner_config_path,
+ )
+ # replace original unet
+ refiner.unet = self.refiner_unet
+ flush()
+
+ diffusers_state_dict = OrderedDict()
+ for k, v in refiner.vae.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
+ diffusers_state_dict[new_key] = v
+ for k, v in refiner.text_encoder_2.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
+ diffusers_state_dict[new_key] = v
+ for k, v in refiner.unet.state_dict().items():
+ new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
+ diffusers_state_dict[new_key] = v
+
+ converted_state_dict = get_ldm_state_dict_from_diffusers(
+ diffusers_state_dict,
+ 'sdxl_refiner',
+ device='cpu',
+ dtype=save_dtype
+ )
+
+ # make sure parent folder exists
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+ save_file(converted_state_dict, output_file, metadata=meta)
+
+ if self.config_file is not None:
+ output_path_no_ext = os.path.splitext(output_file)[0]
+ output_config_path = f"{output_path_no_ext}.yaml"
+ shutil.copyfile(self.config_file, output_config_path)
+
+ def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
+ version_string = '1'
+ if self.is_v2:
+ version_string = '2'
+ if self.is_xl:
+ version_string = 'sdxl'
+ if self.is_ssd:
+ # overwrite sdxl because both will be true here
+ version_string = 'ssd'
+ if self.is_ssd and self.is_vega:
+ version_string = 'vega'
+ # if output file does not end in .safetensors, then it is a directory and we are
+ # saving in diffusers format
+ if not output_file.endswith('.safetensors'):
+ # diffusers
+ if self.is_flux:
+ # only save the unet
+ transformer: FluxTransformer2DModel = self.unet
+ transformer.save_pretrained(
+ save_directory=os.path.join(output_file, 'transformer'),
+ safe_serialization=True,
+ )
+ else:
+
+ self.pipeline.save_pretrained(
+ save_directory=output_file,
+ safe_serialization=True,
+ )
+ # save out meta config
+ meta_path = os.path.join(output_file, 'aitk_meta.yaml')
+ with open(meta_path, 'w') as f:
+ yaml.dump(meta, f)
+
+ else:
+ save_ldm_model_from_diffusers(
+ sd=self,
+ output_file=output_file,
+ meta=meta,
+ save_dtype=save_dtype,
+ sd_version=version_string,
+ )
+ if self.config_file is not None:
+ output_path_no_ext = os.path.splitext(output_file)[0]
+ output_config_path = f"{output_path_no_ext}.yaml"
+ shutil.copyfile(self.config_file, output_config_path)
+
+ def prepare_optimizer_params(
+ self,
+ unet=False,
+ text_encoder=False,
+ text_encoder_lr=None,
+ unet_lr=None,
+ refiner_lr=None,
+ refiner=False,
+ default_lr=1e-6,
+ ):
+ # todo maybe only get locon ones?
+ # not all items are saved; to make it match, we need to match our save mappings
+ # and not train anything not mapped. Also add learning rate
+ version = 'sd1'
+ if self.is_xl:
+ version = 'sdxl'
+ if self.is_v2:
+ version = 'sd2'
+ mapping_filename = f"stable_diffusion_{version}.json"
+ mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename)
+ with open(mapping_path, 'r') as f:
+ mapping = json.load(f)
+ ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']
+
+ trainable_parameters = []
+
+ # we use state dict to find params
+
+ if unet:
+ named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
+ unet_lr = unet_lr if unet_lr is not None else default_lr
+ params = []
+ if self.is_pixart or self.is_auraflow or self.is_flux:
+ for param in named_params.values():
+ if param.requires_grad:
+ params.append(param)
+ else:
+ for key, diffusers_key in ldm_diffusers_keymap.items():
+ if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
+ if named_params[diffusers_key].requires_grad:
+ params.append(named_params[diffusers_key])
+ param_data = {"params": params, "lr": unet_lr}
+ trainable_parameters.append(param_data)
+ print(f"Found {len(params)} trainable parameter in unet")
+
+ if text_encoder:
+ named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
+ text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr
+ params = []
+ for key, diffusers_key in ldm_diffusers_keymap.items():
+ if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
+ if named_params[diffusers_key].requires_grad:
+ params.append(named_params[diffusers_key])
+ param_data = {"params": params, "lr": text_encoder_lr}
+ trainable_parameters.append(param_data)
+
+ print(f"Found {len(params)} trainable parameter in text encoder")
+
+ if refiner:
+ named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
+ state_dict_keys=True)
+ refiner_lr = refiner_lr if refiner_lr is not None else default_lr
+ params = []
+ for key, diffusers_key in ldm_diffusers_keymap.items():
+ diffusers_key = f"refiner_{diffusers_key}"
+ if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
+ if named_params[diffusers_key].requires_grad:
+ params.append(named_params[diffusers_key])
+ param_data = {"params": params, "lr": refiner_lr}
+ trainable_parameters.append(param_data)
+
+ print(f"Found {len(params)} trainable parameter in refiner")
+
+ return trainable_parameters
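+ # (editor's note) the returned list is made of standard torch param groups, e.g.
+ # optimizer = torch.optim.AdamW(sd.prepare_optimizer_params(unet=True, unet_lr=1e-5))
+ # where `sd` is an instance of this class (illustrative only).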
+
+ def save_device_state(self):
+ # saves the current device state for all modules
+ # this is useful for when we want to alter the state and restore it
+ if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
+ unet_has_grad = self.unet.proj_out.weight.requires_grad
+ else:
+ unet_has_grad = self.unet.conv_in.weight.requires_grad
+
+ self.device_state = {
+ **empty_preset,
+ 'vae': {
+ 'training': self.vae.training,
+ 'device': self.vae.device,
+ },
+ 'unet': {
+ 'training': self.unet.training,
+ 'device': self.unet.device,
+ 'requires_grad': unet_has_grad,
+ },
+ }
+ if isinstance(self.text_encoder, list):
+ self.device_state['text_encoder']: List[dict] = []
+ for encoder in self.text_encoder:
+ try:
+ te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
+ except:
+ te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+ self.device_state['text_encoder'].append({
+ 'training': encoder.training,
+ 'device': encoder.device,
+ # todo there has to be a better way to do this
+ 'requires_grad': te_has_grad
+ })
+ else:
+ if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
+ te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+ else:
+ te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad
+
+ self.device_state['text_encoder'] = {
+ 'training': self.text_encoder.training,
+ 'device': self.text_encoder.device,
+ 'requires_grad': te_has_grad
+ }
+ if self.adapter is not None:
+ if isinstance(self.adapter, IPAdapter):
+ requires_grad = self.adapter.image_proj_model.training
+ adapter_device = self.unet.device
+ elif isinstance(self.adapter, T2IAdapter):
+ requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
+ adapter_device = self.adapter.device
+ elif isinstance(self.adapter, ControlNetModel):
+ requires_grad = self.adapter.conv_in.training
+ adapter_device = self.adapter.device
+ elif isinstance(self.adapter, ClipVisionAdapter):
+ requires_grad = self.adapter.embedder.training
+ adapter_device = self.adapter.device
+ elif isinstance(self.adapter, CustomAdapter):
+ requires_grad = self.adapter.training
+ adapter_device = self.adapter.device
+ elif isinstance(self.adapter, ReferenceAdapter):
+ # todo update this!!
+ requires_grad = True
+ adapter_device = self.adapter.device
+ else:
+ raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
+ self.device_state['adapter'] = {
+ 'training': self.adapter.training,
+ 'device': adapter_device,
+ 'requires_grad': requires_grad,
+ }
+
+ if self.refiner_unet is not None:
+ self.device_state['refiner_unet'] = {
+ 'training': self.refiner_unet.training,
+ 'device': self.refiner_unet.device,
+ 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
+ }
+
+ def restore_device_state(self):
+ # restores the device state for all modules
+ # this is useful for when we want to alter the state and restore it
+ if self.device_state is None:
+ return
+ self.set_device_state(self.device_state)
+ self.device_state = None
+
+ def set_device_state(self, state):
+ if state['vae']['training']:
+ self.vae.train()
+ else:
+ self.vae.eval()
+ self.vae.to(state['vae']['device'])
+ if state['unet']['training']:
+ self.unet.train()
+ else:
+ self.unet.eval()
+ self.unet.to(state['unet']['device'])
+ if state['unet']['requires_grad']:
+ self.unet.requires_grad_(True)
+ else:
+ self.unet.requires_grad_(False)
+ if isinstance(self.text_encoder, list):
+ for i, encoder in enumerate(self.text_encoder):
+ if isinstance(state['text_encoder'], list):
+ if state['text_encoder'][i]['training']:
+ encoder.train()
+ else:
+ encoder.eval()
+ encoder.to(state['text_encoder'][i]['device'])
+ encoder.requires_grad_(state['text_encoder'][i]['requires_grad'])
+ else:
+ if state['text_encoder']['training']:
+ encoder.train()
+ else:
+ encoder.eval()
+ encoder.to(state['text_encoder']['device'])
+ encoder.requires_grad_(state['text_encoder']['requires_grad'])
+ else:
+ if state['text_encoder']['training']:
+ self.text_encoder.train()
+ else:
+ self.text_encoder.eval()
+ self.text_encoder.to(state['text_encoder']['device'])
+ self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
+
+ if self.adapter is not None:
+ self.adapter.to(state['adapter']['device'])
+ self.adapter.requires_grad_(state['adapter']['requires_grad'])
+ if state['adapter']['training']:
+ self.adapter.train()
+ else:
+ self.adapter.eval()
+
+ if self.refiner_unet is not None:
+ self.refiner_unet.to(state['refiner_unet']['device'])
+ self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
+ if state['refiner_unet']['training']:
+ self.refiner_unet.train()
+ else:
+ self.refiner_unet.eval()
+ flush()
+
+ def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
+ # sets a preset for device state
+
+ # save current state first
+ self.save_device_state()
+
+ active_modules = []
+ training_modules = []
+ if device_state_preset in ['cache_latents']:
+ active_modules = ['vae']
+ if device_state_preset in ['cache_clip']:
+ active_modules = ['clip']
+ if device_state_preset in ['generate']:
+ active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
+
+ state = copy.deepcopy(empty_preset)
+ # vae
+ state['vae'] = {
+ 'training': 'vae' in training_modules,
+ 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
+ 'requires_grad': 'vae' in training_modules,
+ }
+
+ # unet
+ state['unet'] = {
+ 'training': 'unet' in training_modules,
+ 'device': self.device_torch if 'unet' in active_modules else 'cpu',
+ 'requires_grad': 'unet' in training_modules,
+ }
+
+ if self.refiner_unet is not None:
+ state['refiner_unet'] = {
+ 'training': 'refiner_unet' in training_modules,
+ 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
+ 'requires_grad': 'refiner_unet' in training_modules,
+ }
+
+ # text encoder
+ if isinstance(self.text_encoder, list):
+ state['text_encoder'] = []
+ for i, encoder in enumerate(self.text_encoder):
+ state['text_encoder'].append({
+ 'training': 'text_encoder' in training_modules,
+ 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
+ 'requires_grad': 'text_encoder' in training_modules,
+ })
+ else:
+ state['text_encoder'] = {
+ 'training': 'text_encoder' in training_modules,
+ 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
+ 'requires_grad': 'text_encoder' in training_modules,
+ }
+
+ if self.adapter is not None:
+ state['adapter'] = {
+ 'training': 'adapter' in training_modules,
+ 'device': self.device_torch if 'adapter' in active_modules else 'cpu',
+ 'requires_grad': 'adapter' in training_modules,
+ }
+
+ self.set_device_state(state)
+
+ def text_encoder_to(self, *args, **kwargs):
+ if isinstance(self.text_encoder, list):
+ for encoder in self.text_encoder:
+ encoder.to(*args, **kwargs)
+ else:
+ self.text_encoder.to(*args, **kwargs)
diff --git a/toolkit/style.py b/toolkit/style.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ac33fa710b3286323357abc50b13e9bcda9aec
--- /dev/null
+++ b/toolkit/style.py
@@ -0,0 +1,232 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from torchvision import models
+
+
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def tensor_size(tensor):
+ channels = tensor.shape[1]
+ height = tensor.shape[2]
+ width = tensor.shape[3]
+ return channels * height * width
+
+class ContentLoss(nn.Module):
+
+ def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ super(ContentLoss, self).__init__()
+ self.single_target = single_target
+ self.device = device
+ self.loss = None
+
+ def forward(self, stacked_input):
+
+ if self.single_target:
+ split_size = stacked_input.size()[0] // 2
+ pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
+ else:
+ split_size = stacked_input.size()[0] // 3
+ pred_layer, _, target_layer = torch.split(stacked_input, split_size, dim=0)
+
+ content_size = tensor_size(pred_layer)
+
+ # Define the separate loss function
+ def separated_loss(y_pred, y_true):
+ y_pred = y_pred.float()
+ y_true = y_true.float()
+ diff = torch.abs(y_pred - y_true)
+ l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
+ return 2. * l2 / content_size
+
+ # Calculate itemized loss
+ pred_itemized_loss = separated_loss(pred_layer, target_layer)
+ # check if is nan
+ if torch.isnan(pred_itemized_loss).any():
+ print('pred_itemized_loss is nan')
+
+ # Calculate the mean of itemized loss
+ loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
+ self.loss = loss
+
+ return stacked_input
+
+
+def convert_to_gram_matrix(inputs):
+ inputs = inputs.float()
+ shape = inputs.size()
+ batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
+ size = height * width * filters
+
+ feats = inputs.view(batch, filters, height * width)
+ feats_t = feats.transpose(1, 2)
+ grams_raw = torch.matmul(feats, feats_t)
+ gram_matrix = grams_raw / size
+
+ return gram_matrix
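+
+# Editor's note: this is the standard Gram-matrix formulation, G = F @ F.T / (C * H * W),
+# with F the (C, H*W) feature map per batch item; dividing by the size keeps the style
+# loss scale independent of layer resolution.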
+
+
+######################################################################
+# Now the style loss module looks almost exactly like the content loss
+# module. The style distance is also computed using the mean square
+# error between :math:`G_{XL}` and :math:`G_{SL}`.
+#
+
+class StyleLoss(nn.Module):
+
+ def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
+ super(StyleLoss, self).__init__()
+ self.single_target = single_target
+ self.device = device
+
+ def forward(self, stacked_input):
+ input_dtype = stacked_input.dtype
+ stacked_input = stacked_input.float()
+ if self.single_target:
+ split_size = stacked_input.size()[0] // 2
+ preds, style_target = torch.split(stacked_input, split_size, dim=0)
+ else:
+ split_size = stacked_input.size()[0] // 3
+ preds, style_target, _ = torch.split(stacked_input, split_size, dim=0)
+
+ def separated_loss(y_pred, y_true):
+ gram_size = y_true.size(1) * y_true.size(2)
+ sum_axis = (1, 2)
+ diff = torch.abs(y_pred - y_true)
+ raw_loss = torch.sum(diff ** 2, dim=sum_axis, keepdim=True)
+ return raw_loss / gram_size
+
+ target_grams = convert_to_gram_matrix(style_target)
+ pred_grams = convert_to_gram_matrix(preds)
+ itemized_loss = separated_loss(pred_grams, target_grams)
+ # check if is nan
+ if torch.isnan(itemized_loss).any():
+ print('itemized_loss is nan')
+ # reshape itemized loss to be (batch, 1, 1, 1)
+ itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
+ # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
+ loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
+ self.loss = loss.to(input_dtype).float()
+ return stacked_input.to(input_dtype)
+
+
+# create a module to normalize input image so we can easily put it in a
+# ``nn.Sequential``
+class Normalization(nn.Module):
+ def __init__(self, device, dtype=torch.float32):
+ super(Normalization, self).__init__()
+ mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
+ std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+ self.dtype = dtype
+ # .view the mean and std to make them [C x 1 x 1] so that they can
+ # directly work with image Tensor of shape [B x C x H x W].
+ # B is batch size. C is number of channels. H is height and W is width.
+ # mean/std are already tensors; wrapping them in torch.tensor() again would only trigger a copy-construct warning
+ self.mean = mean.view(-1, 1, 1)
+ self.std = std.view(-1, 1, 1)
+
+ def forward(self, stacked_input):
+ # cast to float 32 if not already # only necessary when processing gram matrix
+ # if stacked_input.dtype != torch.float32:
+ # stacked_input = stacked_input.float()
+ # remove alpha channel if it exists
+ if stacked_input.shape[1] == 4:
+ stacked_input = stacked_input[:, :3, :, :]
+ # min-max normalization to 0-1 (currently disabled; see the commented-out lines below)
+ in_min = torch.min(stacked_input)
+ in_max = torch.max(stacked_input)
+ # norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
+ # return (norm_stacked_input - self.mean) / self.std
+ return ((stacked_input - self.mean) / self.std).to(self.dtype)
+
+
+class OutputLayer(nn.Module):
+ def __init__(self, name='output_layer'):
+ super(OutputLayer, self).__init__()
+ self.name = name
+ self.tensor = None
+
+ def forward(self, stacked_input):
+ self.tensor = stacked_input
+ return stacked_input
+
+
+def get_style_model_and_losses(
+ single_target=True, # False expects 3 stacked targets; don't remember why this was added initially, this is old code
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ output_layer_name=None,
+ dtype=torch.float32
+):
+ # content_layers = ['conv_4']
+ # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
+ content_layers = ['conv2_2', 'conv3_2', 'conv4_2']
+ style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
+ cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
+ # set all weights in the model to our dtype
+ # for layer in cnn.children():
+ # layer.to(dtype=dtype)
+
+ # normalization module
+ normalization = Normalization(device, dtype=dtype).to(device)
+
+ # just in order to have iterable access to our lists of content/style
+ # losses
+ content_losses = []
+ style_losses = []
+
+ # assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential``
+ # to put in modules that are supposed to be activated sequentially
+ model = nn.Sequential(normalization)
+
+ i = 0 # increment every time we see a conv
+ block = 1
+ children = list(cnn.children())
+
+ output_layer = None
+
+ for layer in children:
+ if isinstance(layer, nn.Conv2d):
+ i += 1
+ name = f'conv{block}_{i}_raw'
+ elif isinstance(layer, nn.ReLU):
+ # name = 'relu_{}'.format(i)
+ name = f'conv{block}_{i}' # target this
+ # The in-place version doesn't play very nicely with the ``ContentLoss``
+ # and ``StyleLoss`` we insert below. So we replace with out-of-place
+ # ones here.
+ layer = nn.ReLU(inplace=False)
+ elif isinstance(layer, nn.MaxPool2d):
+ name = 'pool_{}'.format(i)
+ block += 1
+ i = 0
+ elif isinstance(layer, nn.BatchNorm2d):
+ name = 'bn_{}'.format(i)
+ else:
+ raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
+
+ model.add_module(name, layer)
+
+ if name in content_layers:
+ # add content loss:
+ content_loss = ContentLoss(single_target=single_target, device=device)
+ model.add_module("content_loss_{}_{}".format(block, i), content_loss)
+ content_losses.append(content_loss)
+
+ if name in style_layers:
+ # add style loss:
+ style_loss = StyleLoss(single_target=single_target, device=device)
+ model.add_module("style_loss_{}_{}".format(block, i), style_loss)
+ style_losses.append(style_loss)
+
+ if output_layer_name is not None and name == output_layer_name:
+ output_layer = OutputLayer(name)
+ model.add_module("output_layer_{}_{}".format(block, i), output_layer)
+
+ # now we trim off the layers after the last content and style losses
+ for i in range(len(model) - 1, -1, -1):
+ if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer):
+ break
+
+ model = model[:(i + 1)]
+ model.to(dtype=dtype)
+
+ return model, style_losses, content_losses, output_layer
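+
+
+# Illustrative usage sketch (editor's addition; tensor names are hypothetical):
+#   model, style_losses, content_losses, _ = get_style_model_and_losses(single_target=True)
+#   stacked = torch.cat([pred_images, target_images], dim=0)  # prediction batch, then target batch
+#   model(stacked)  # populates .loss on every ContentLoss / StyleLoss module
+#   loss = sum(sl.loss.mean() for sl in style_losses) + sum(cl.loss.mean() for cl in content_losses)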
diff --git a/toolkit/timer.py b/toolkit/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4fecba1321aa856808bd7a3290511882b84627
--- /dev/null
+++ b/toolkit/timer.py
@@ -0,0 +1,65 @@
+import time
+from collections import OrderedDict, deque
+
+
+class Timer:
+ def __init__(self, name='Timer', max_buffer=10):
+ self.name = name
+ self.max_buffer = max_buffer
+ self.timers = OrderedDict()
+ self.active_timers = {}
+ self.current_timer = None # Used for the context manager functionality
+
+ def start(self, timer_name):
+ if timer_name not in self.timers:
+ self.timers[timer_name] = deque(maxlen=self.max_buffer)
+ self.active_timers[timer_name] = time.time()
+
+ def cancel(self, timer_name):
+ """Cancel an active timer."""
+ if timer_name in self.active_timers:
+ del self.active_timers[timer_name]
+
+ def stop(self, timer_name):
+ if timer_name not in self.active_timers:
+ raise ValueError(f"Timer '{timer_name}' was not started!")
+
+ elapsed_time = time.time() - self.active_timers[timer_name]
+ self.timers[timer_name].append(elapsed_time)
+
+ # Clean up active timers
+ del self.active_timers[timer_name]
+
+ # the deque was created with maxlen=max_buffer, so it already discards the oldest entries; this check is a defensive no-op
+ if len(self.timers[timer_name]) > self.max_buffer:
+ self.timers[timer_name].popleft()
+
+ def print(self):
+ print(f"\nTimer '{self.name}':")
+ # sort by longest at top
+ for timer_name, timings in sorted(self.timers.items(), key=lambda x: sum(x[1]), reverse=True):
+ avg_time = sum(timings) / len(timings)
+ print(f" - {avg_time:.4f}s avg - {timer_name}, num = {len(timings)}")
+
+ print('')
+
+ def reset(self):
+ self.timers.clear()
+ self.active_timers.clear()
+
+ def __call__(self, timer_name):
+ """Enable the use of the Timer class as a context manager."""
+ self.current_timer = timer_name
+ self.start(timer_name)
+ return self
+
+ def __enter__(self):
+ return self # returning the timer lets `with timer('name') as t:` bind the instance
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_type is None:
+ # No exceptions, stop the timer normally
+ self.stop(self.current_timer)
+ else:
+ # There was an exception, cancel the timer
+ self.cancel(self.current_timer)
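+
+
+# Illustrative usage (editor's sketch):
+#   timer = Timer('train_loop', max_buffer=20)
+#   with timer('forward'):
+#       pass  # timed work; the timer is cancelled if an exception escapes the block
+#   timer.start('backward')
+#   timer.stop('backward')
+#   timer.print()  # prints average time per label, slowest first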
diff --git a/toolkit/train_pipelines.py b/toolkit/train_pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9cc623cd55e802bcad1de41cd90be6a57d2743a
--- /dev/null
+++ b/toolkit/train_pipelines.py
@@ -0,0 +1,316 @@
+from typing import Optional, Tuple, Callable, Dict, Any, Union, List
+
+import torch
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+
+from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.pipelines import CustomStableDiffusionXLPipeline
+
+
+class TransferStableDiffusionXLPipeline(CustomStableDiffusionXLPipeline):
+ def transfer_diffuse(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ target_unet: Optional[torch.nn.Module] = None,
+ pre_condition_callback = None,
+ each_step_callback = None,
+ network: Optional[LoRASpecialNetwork] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 8.1 Apply denoising_end
+        if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ conditioned_noise_pred, conditioned_latent_model_input = pre_condition_callback(
+ noise_pred.clone().detach(),
+ latent_model_input.clone().detach(),
+ )
+
+ # start grad
+ with torch.enable_grad():
+ with network:
+ assert network.is_active
+ noise_train_pred = target_unet(
+ conditioned_latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+ each_step_callback(conditioned_noise_pred, noise_train_pred)
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
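+# Usage notes (inferred from the denoising loop above, not a documented API): the
+# caller must supply `target_unet`, `network`, `pre_condition_callback` and
+# `each_step_callback`; all of them are used on every step without a None check.
+# Expected callback shapes:
+#   pre_condition_callback(noise_pred, latent_model_input) -> (cond_noise_pred, cond_latent_input)
+#   each_step_callback(cond_noise_pred, noise_train_pred) -> None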
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..592f7cc681c51e653602883436c5750cb53c8873
--- /dev/null
+++ b/toolkit/train_tools.py
@@ -0,0 +1,768 @@
+import argparse
+import hashlib
+import json
+import os
+import time
+from typing import TYPE_CHECKING, Union, List
+import sys
+
+from torch.cuda.amp import GradScaler
+
+from toolkit.paths import SD_SCRIPTS_ROOT
+
+sys.path.append(SD_SCRIPTS_ROOT)
+
+from diffusers import (
+ DDPMScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ DPMSolverSinglestepScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+ DDIMScheduler,
+ EulerDiscreteScheduler,
+ HeunDiscreteScheduler,
+ KDPM2DiscreteScheduler,
+ KDPM2AncestralDiscreteScheduler
+)
+import torch
+import re
+from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel
+
+SCHEDULER_LINEAR_START = 0.00085
+SCHEDULER_LINEAR_END = 0.0120
+SCHEDULER_TIMESTEPS = 1000
+SCHEDLER_SCHEDULE = "scaled_linear"
+
+UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
+TEXT_ENCODER_2_PROJECTION_DIM = 1280
+UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
+
+
+def get_torch_dtype(dtype_str):
+ # if it is a torch dtype, return it
+ if isinstance(dtype_str, torch.dtype):
+ return dtype_str
+    if dtype_str in ("float", "fp32", "single", "float32"):
+        return torch.float
+    if dtype_str in ("fp16", "half", "float16"):
+        return torch.float16
+    if dtype_str in ("bf16", "bfloat16"):
+        return torch.bfloat16
+    if dtype_str in ("8bit", "e4m3fn", "float8"):
+        return torch.float8_e4m3fn
+ return dtype_str
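+# e.g. get_torch_dtype("bf16") -> torch.bfloat16; an unrecognized string falls
+# through and is returned unchanged.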
+
+
+def replace_filewords_prompt(prompt, args: argparse.Namespace):
+ # if name_replace attr in args (may not be)
+ if hasattr(args, "name_replace") and args.name_replace is not None:
+ # replace [name] to args.name_replace
+ prompt = prompt.replace("[name]", args.name_replace)
+ if hasattr(args, "prepend") and args.prepend is not None:
+ # prepend to every item in prompt file
+ prompt = args.prepend + ' ' + prompt
+ if hasattr(args, "append") and args.append is not None:
+ # append to every item in prompt file
+ prompt = prompt + ' ' + args.append
+ return prompt
+
+
+def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace):
+ # if name_replace attr in args (may not be)
+ if hasattr(args, "name_replace") and args.name_replace is not None:
+        if len(dataset_group.image_data) == 0:
+            raise ValueError("dataset_group.image_data is empty")
+ for key in dataset_group.image_data:
+ dataset_group.image_data[key].caption = dataset_group.image_data[key].caption.replace(
+ "[name]", args.name_replace)
+
+ return dataset_group
+
+
+def get_seeds_from_latents(latents):
+ # latents shape = (batch_size, 4, height, width)
+ # for speed we only use 8x8 slice of the first channel
+ seeds = []
+
+ # split batch up
+ for i in range(latents.shape[0]):
+ # use only first channel, multiply by 255 and convert to int
+ tensor = latents[i, 0, :, :] * 255.0 # shape = (height, width)
+ # slice 8x8
+ tensor = tensor[:8, :8]
+ # clip to 0-255
+ tensor = torch.clamp(tensor, 0, 255)
+ # convert to 8bit int
+ tensor = tensor.to(torch.uint8)
+ # convert to bytes
+ tensor_bytes = tensor.cpu().numpy().tobytes()
+ # hash
+ hash_object = hashlib.sha256(tensor_bytes)
+ # get hex
+ hex_dig = hash_object.hexdigest()
+ # convert to int
+ seed = int(hex_dig, 16) % (2 ** 32)
+ # append
+ seeds.append(seed)
+ return seeds
+
+
+def get_noise_from_latents(latents):
+ seed_list = get_seeds_from_latents(latents)
+ noise = []
+ for seed in seed_list:
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ noise.append(torch.randn_like(latents[0]))
+ return torch.stack(noise)
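+
+
+def _noise_from_latents_sketch():
+    # Illustrative sketch (the helper name is ours, not part of the toolkit API):
+    # the seed is hashed from an 8x8 slice of each latent, so identical latents
+    # always map to identical noise across calls.
+    latents = torch.randn(2, 4, 64, 64)
+    assert torch.equal(get_noise_from_latents(latents), get_noise_from_latents(latents))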
+
+
+# mix 0 is completely noise mean, mix 1 is completely target mean
+
+def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
+ dim = dim or (1, 2, 3)
+ # reduce mean of noise on dim 2, 3, keeping 0 and 1 intact
+ noise_mean = noise.mean(dim=dim, keepdim=True)
+ target_mean = target.mean(dim=dim, keepdim=True)
+
+ new_noise_mean = mix * target_mean + (1 - mix) * noise_mean
+
+ noise = noise - noise_mean + new_noise_mean
+ return noise
+
+
+# https://www.crosslabs.org//blog/diffusion-with-offset-noise
+def apply_noise_offset(noise, noise_offset):
+    if noise_offset is None or abs(noise_offset) < 1e-6:
+ return noise
+ noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
+ return noise
+
+
+if TYPE_CHECKING:
+ from toolkit.stable_diffusion_model import PromptEmbeds
+
+
+def concat_prompt_embeddings(
+ unconditional: 'PromptEmbeds',
+ conditional: 'PromptEmbeds',
+ n_imgs: int,
+):
+ from toolkit.stable_diffusion_model import PromptEmbeds
+ text_embeds = torch.cat(
+ [unconditional.text_embeds, conditional.text_embeds]
+ ).repeat_interleave(n_imgs, dim=0)
+ pooled_embeds = None
+ if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
+ pooled_embeds = torch.cat(
+ [unconditional.pooled_embeds, conditional.pooled_embeds]
+ ).repeat_interleave(n_imgs, dim=0)
+ return PromptEmbeds([text_embeds, pooled_embeds])
+
+
+def addnet_hash_safetensors(b):
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
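+
+
+# Minimal usage sketch (hypothetical file path): the function expects a binary
+# file object and skips the safetensors JSON header before hashing.
+#
+#     with open("my_lora.safetensors", "rb") as f:
+#         print(addnet_hash_safetensors(f))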
+
+
+def addnet_hash_legacy(b):
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+ m = hashlib.sha256()
+
+ b.seek(0x100000)
+ m.update(b.read(0x10000))
+ return m.hexdigest()[0:8]
+
+
+if TYPE_CHECKING:
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+
+
+def text_tokenize(
+ tokenizer: 'CLIPTokenizer',
+ prompts: list[str],
+ truncate: bool = True,
+ max_length: int = None,
+ max_length_multiplier: int = 4,
+):
+    # allow for up to 4x the max length for long prompts
+ if max_length is None:
+ if truncate:
+ max_length = tokenizer.model_max_length
+ else:
+ # allow up to 4x the max length for long prompts
+ max_length = tokenizer.model_max_length * max_length_multiplier
+
+ input_ids = tokenizer(
+ prompts,
+ padding='max_length',
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ ).input_ids
+
+ if truncate or max_length == tokenizer.model_max_length:
+ return input_ids
+ else:
+ # remove additional padding
+ num_chunks = input_ids.shape[1] // tokenizer.model_max_length
+ chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)
+
+ # New list to store non-redundant chunks
+ non_redundant_chunks = []
+
+ for chunk in chunks:
+ if not chunk.eq(chunk[0, 0]).all(): # Check if all elements in the chunk are the same as the first element
+ non_redundant_chunks.append(chunk)
+
+ input_ids = torch.cat(non_redundant_chunks, dim=1)
+ return input_ids
+
+
+# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
+def text_encode_xl(
+ text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
+ tokens: torch.FloatTensor,
+ num_images_per_prompt: int = 1,
+ max_length: int = 77, # not sure what default to put here, always pass one?
+ truncate: bool = True,
+):
+ if truncate:
+ # normal short prompt 77 tokens max
+ prompt_embeds = text_encoder(
+ tokens.to(text_encoder.device), output_hidden_states=True
+ )
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
+ else:
+ # handle long prompts
+ prompt_embeds_list = []
+ tokens = tokens.to(text_encoder.device)
+ pooled_prompt_embeds = None
+ for i in range(0, tokens.shape[-1], max_length):
+ # todo run it through the in a single batch
+ section_tokens = tokens[:, i: i + max_length]
+ embeds = text_encoder(section_tokens, output_hidden_states=True)
+ pooled_prompt_embed = embeds[0]
+ if pooled_prompt_embeds is None:
+ # we only want the first ( I think??)
+ pooled_prompt_embeds = pooled_prompt_embed
+ prompt_embed = embeds.hidden_states[-2] # always penultimate layer
+ prompt_embeds_list.append(prompt_embed)
+
+ prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def encode_prompts_xl(
+ tokenizers: list['CLIPTokenizer'],
+ text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
+ prompts: list[str],
+ prompts2: Union[list[str], None],
+ num_images_per_prompt: int = 1,
+ use_text_encoder_1: bool = True, # sdxl
+ use_text_encoder_2: bool = True, # sdxl
+ truncate: bool = True,
+ max_length=None,
+ dropout_prob=0.0,
+) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+    # text_encoder and text_encoder_2's penultimate layer's output
+ text_embeds_list = []
+ pooled_text_embeds = None # always text_encoder_2's pool
+ if prompts2 is None:
+ prompts2 = prompts
+
+ for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
+ # todo, we are using a blank string to ignore that encoder for now.
+ # find a better way to do this (zeroing?, removing it from the unet?)
+ prompt_list_to_use = prompts if idx == 0 else prompts2
+ if idx == 0 and not use_text_encoder_1:
+ prompt_list_to_use = ["" for _ in prompts]
+ if idx == 1 and not use_text_encoder_2:
+ prompt_list_to_use = ["" for _ in prompts]
+
+ if dropout_prob > 0.0:
+ # randomly drop out prompts
+ prompt_list_to_use = [
+ prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompt_list_to_use
+ ]
+
+ text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
+ # set the max length for the next one
+ if idx == 0:
+ max_length = text_tokens_input_ids.shape[-1]
+
+ text_embeds, pooled_text_embeds = text_encode_xl(
+ text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
+ truncate=truncate
+ )
+
+ text_embeds_list.append(text_embeds)
+
+ bs_embed = pooled_text_embeds.shape[0]
+ pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
+
+def encode_prompts_sd3(
+ tokenizers: list['CLIPTokenizer'],
+ text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
+ prompts: list[str],
+ num_images_per_prompt: int = 1,
+ truncate: bool = True,
+ max_length=None,
+ dropout_prob=0.0,
+ pipeline = None,
+):
+ text_embeds_list = []
+ pooled_text_embeds = None # always text_encoder_2's pool
+
+ prompt_2 = prompts
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ prompt_3 = prompts
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+ device = text_encoders[0].device
+
+ prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
+ prompt=prompts,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=0,
+ )
+ prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
+ prompt=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=1,
+ )
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+ t5_prompt_embed = pipeline._get_t5_prompt_embeds(
+ prompt=prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device
+ )
+
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# ref for long prompts https://github.com/huggingface/diffusers/issues/2136
+def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
+ if max_length is None and not truncate:
+        raise ValueError("max_length must be set if truncate is False")
+ try:
+ tokens = tokens.to(text_encoder.device)
+ except Exception as e:
+ print(e)
+ print("tokens.device", tokens.device)
+ print("text_encoder.device", text_encoder.device)
+ raise e
+
+ if truncate:
+ return text_encoder(tokens)[0]
+ else:
+ # handle long prompts
+ prompt_embeds_list = []
+ for i in range(0, tokens.shape[-1], max_length):
+ prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
+ prompt_embeds_list.append(prompt_embeds)
+
+ return torch.cat(prompt_embeds_list, dim=1)
+
+
+def encode_prompts(
+ tokenizer: 'CLIPTokenizer',
+ text_encoder: 'CLIPTextModel',
+ prompts: list[str],
+ truncate: bool = True,
+ max_length=None,
+ dropout_prob=0.0,
+):
+ if max_length is None:
+ max_length = tokenizer.model_max_length
+
+ if dropout_prob > 0.0:
+ # randomly drop out prompts
+ prompts = [
+ prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+ ]
+
+ text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
+ text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)
+
+ return text_embeddings
+
+
+def encode_prompts_pixart(
+ tokenizer: 'T5Tokenizer',
+ text_encoder: 'T5EncoderModel',
+ prompts: list[str],
+ truncate: bool = True,
+ max_length=None,
+ dropout_prob=0.0,
+):
+ if max_length is None:
+        # 120 tokens, see Section 3.1 of the PixArt-alpha paper.
+ max_length = 120
+
+ if dropout_prob > 0.0:
+ # randomly drop out prompts
+ prompts = [
+ prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+ ]
+
+ text_inputs = tokenizer(
+ prompts,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1])
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(text_encoder.device)
+
+ text_input_ids = text_input_ids.to(text_encoder.device)
+
+ prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask)
+
+ return prompt_embeds.last_hidden_state, prompt_attention_mask
+
+
+def encode_prompts_auraflow(
+ tokenizer: 'T5Tokenizer',
+ text_encoder: 'UMT5EncoderModel',
+ prompts: list[str],
+ truncate: bool = True,
+ max_length=None,
+ dropout_prob=0.0,
+):
+ if max_length is None:
+ max_length = 256
+
+ if dropout_prob > 0.0:
+ # randomly drop out prompts
+ prompts = [
+ prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+ ]
+
+ device = text_encoder.device
+
+ text_inputs = tokenizer(
+ prompts,
+ truncation=True,
+ max_length=max_length,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs["input_ids"]
+ untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1])
+
+ text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
+ prompt_embeds = text_encoder(**text_inputs)[0]
+ prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+ prompt_embeds = prompt_embeds * prompt_attention_mask
+
+ return prompt_embeds, prompt_attention_mask
+
+def encode_prompts_flux(
+ tokenizer: List[Union['CLIPTokenizer','T5Tokenizer']],
+ text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
+ prompts: list[str],
+ truncate: bool = True,
+ max_length=None,
+ dropout_prob=0.0,
+ attn_mask: bool = False,
+):
+ if max_length is None:
+ max_length = 512
+
+ if dropout_prob > 0.0:
+ # randomly drop out prompts
+ prompts = [
+ prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+ ]
+
+ device = text_encoder[0].device
+ dtype = text_encoder[0].dtype
+
+ batch_size = len(prompts)
+
+ # clip
+ text_inputs = tokenizer[0](
+ prompts,
+ padding="max_length",
+ max_length=tokenizer[0].model_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+
+ prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ pooled_prompt_embeds = prompt_embeds.pooler_output
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
+
+ # T5
+ text_inputs = tokenizer[1](
+ prompts,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = text_encoder[1].dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ if attn_mask:
+ prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+ prompt_embeds = prompt_embeds * prompt_attention_mask.to(dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+
+ return prompt_embeds, pooled_prompt_embeds
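+
+# Usage sketch (variable names are placeholders): the lists must be ordered CLIP
+# first, then T5, matching the indexing above.
+#
+#     prompt_embeds, pooled_embeds = encode_prompts_flux(
+#         [clip_tokenizer, t5_tokenizer],
+#         [clip_text_encoder, t5_text_encoder],
+#         ["a photo of a cat"],
+#     )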
+
+
+# for XL
+def get_add_time_ids(
+ height: int,
+ width: int,
+ dynamic_crops: bool = False,
+ dtype: torch.dtype = torch.float32,
+):
+ if dynamic_crops:
+ # random float scale between 1 and 3
+ random_scale = torch.rand(1).item() * 2 + 1
+ original_size = (int(height * random_scale), int(width * random_scale))
+ # random position
+ crops_coords_top_left = (
+ torch.randint(0, original_size[0] - height, (1,)).item(),
+ torch.randint(0, original_size[1] - width, (1,)).item(),
+ )
+ target_size = (height, width)
+ else:
+ original_size = (height, width)
+ crops_coords_top_left = (0, 0)
+ target_size = (height, width)
+
+ # this is expected as 6
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ # this is expected as 2816
+ passed_add_embed_dim = (
+ UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
+ + TEXT_ENCODER_2_PROJECTION_DIM # + 1280
+ )
+ if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
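+
+# Shape sketch: for a 1024x1024 image without dynamic crops the six ids are
+# (orig_h, orig_w, crop_top, crop_left, target_h, target_w) = (1024, 1024, 0, 0, 1024, 1024),
+# and the check above works out to 256 * 6 + 1280 = 2816.
+#
+#     add_time_ids = get_add_time_ids(1024, 1024)  # tensor of shape (1, 6)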
+
+
+def concat_embeddings(
+ unconditional: torch.FloatTensor,
+ conditional: torch.FloatTensor,
+ n_imgs: int,
+):
+ return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
+
+
+def add_all_snr_to_noise_scheduler(noise_scheduler, device):
+ try:
+ if hasattr(noise_scheduler, "all_snr"):
+ return
+ # compute it
+ with torch.no_grad():
+ alphas_cumprod = noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+ alpha = sqrt_alphas_cumprod
+ sigma = sqrt_one_minus_alphas_cumprod
+ all_snr = (alpha / sigma) ** 2
+ all_snr.requires_grad = False
+ noise_scheduler.all_snr = all_snr.to(device)
+ except Exception as e:
+ # just move on
+ pass
+
+
+def get_all_snr(noise_scheduler, device):
+ if hasattr(noise_scheduler, "all_snr"):
+ return noise_scheduler.all_snr.to(device)
+ # compute it
+ with torch.no_grad():
+ alphas_cumprod = noise_scheduler.alphas_cumprod
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+ alpha = sqrt_alphas_cumprod
+ sigma = sqrt_one_minus_alphas_cumprod
+ all_snr = (alpha / sigma) ** 2
+ all_snr.requires_grad = False
+ return all_snr.to(device)
+
+class LearnableSNRGamma:
+ """
+    This is a trainer for a learnable SNR gamma.
+    It adapts to the dataset and attempts to adjust the SNR multiplier to balance the loss across timesteps.
+ """
+ def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
+ self.device = device
+ self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
+ self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
+ self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
+ self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
+ self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
+ self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
+ self.buffer = []
+ self.max_buffer_size = 20
+
+ def forward(self, loss, timesteps):
+        # run our training loop for the learnable SNR gamma here and return the values detached
+ loss = loss.detach()
+ with torch.no_grad():
+ loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
+ for loss_chunk in loss_chunks:
+ self.buffer.append(loss_chunk.mean().detach())
+ if len(self.buffer) > self.max_buffer_size:
+ self.buffer.pop(0)
+ all_snr = get_all_snr(self.noise_scheduler, loss.device)
+ snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
+ base_snrs = snr.clone().detach()
+ snr.requires_grad = True
+ snr = (snr + self.offset_1) * self.scale + self.offset_2
+
+ gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
+ snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
+ snr_adjusted_loss = loss * snr_weight
+ with torch.no_grad():
+ target = torch.mean(torch.stack(self.buffer)).detach()
+
+ # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
+ squared_differences = (snr_adjusted_loss - target) ** 2
+ local_loss = torch.mean(squared_differences)
+ local_loss.backward()
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()
+
+
+def apply_learnable_snr_gos(
+ loss,
+ timesteps,
+ learnable_snr_trainer: LearnableSNRGamma
+):
+
+ snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)
+
+ snr = (snr + offset_1) * scale + offset_2
+
+ gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+ snr_weight = torch.abs(gamma_over_snr).float().to(loss.device) # directly using gamma over snr
+ snr_adjusted_loss = loss * snr_weight
+
+ return snr_adjusted_loss
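+
+# Usage sketch (assumes a per-sample loss tensor and a DDPM-style scheduler): the
+# trainer updates its own parameters on every call, so create it once and reuse
+# it across training steps.
+#
+#     lsnr_trainer = LearnableSNRGamma(noise_scheduler, device='cuda')
+#     loss = apply_learnable_snr_gos(per_sample_loss, timesteps, lsnr_trainer).mean()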
+
+
+def apply_snr_weight(
+ loss,
+ timesteps,
+ noise_scheduler: Union['DDPMScheduler'],
+ gamma,
+ fixed=False,
+):
+    # fetch all_snr from the noise scheduler if present, otherwise compute it
+ all_snr = get_all_snr(noise_scheduler, loss.device)
+ # step_indices = []
+ # for t in timesteps:
+ # for i, st in enumerate(noise_scheduler.timesteps):
+ # if st == t:
+ # step_indices.append(i)
+ # break
+ # this breaks on some schedulers
+ # step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
+
+ offset = 0
+ if noise_scheduler.timesteps[0] == 1000:
+ offset = 1
+ snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
+ gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+ if fixed:
+ snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
+ else:
+ snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
+ snr_adjusted_loss = loss * snr_weight
+
+ return snr_adjusted_loss
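+
+# Usage sketch: gamma=5.0 is a commonly used value for min-SNR weighting; with
+# fixed=False (the default) the weight is min(gamma / snr, 1).
+#
+#     loss = apply_snr_weight(per_sample_loss, timesteps, noise_scheduler, gamma=5.0)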
+
+
+def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
+ mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
+ mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+ timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+ out_chunks = []
+    # process each batch element separately
+ for idx in range(model_output.shape[0]):
+ sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
+ dtype=model_output.dtype, device=model_output.device)
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Preconditioning of the model outputs.
+ out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
+ out_chunks.append(out)
+ return torch.cat(out_chunks, dim=0)
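+
+
+# Math note (sketch, following the flow-matching convention implied above): with a
+# velocity prediction v and noisy input x_t, each chunk computes x_t - sigma_t * v,
+# i.e. an estimate of the clean sample, in the spirit of EDM preconditioning.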
diff --git a/toolkit/util/inverse_cfg.py b/toolkit/util/inverse_cfg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c85544a95c1a81cd5f7f6e4cf9ca3408e92c81e
--- /dev/null
+++ b/toolkit/util/inverse_cfg.py
@@ -0,0 +1,25 @@
+import torch
+
+
+def inverse_classifier_guidance(
+ noise_pred_cond: torch.Tensor,
+ noise_pred_uncond: torch.Tensor,
+ guidance_scale: torch.Tensor
+):
+ """
+ Adjust the noise_pred_cond for the classifier free guidance algorithm
+ to ensure that the final noise prediction equals the original noise_pred_cond.
+ """
+ # To make noise_pred equal noise_pred_cond_orig, we adjust noise_pred_cond
+ # based on the formula used in the algorithm.
+ # We derive the formula to find the correct adjustment for noise_pred_cond:
+ # noise_pred_cond = (noise_pred_cond_orig - noise_pred_uncond * guidance_scale) / (guidance_scale - 1)
+ # It's important to check if guidance_scale is not 1 to avoid division by zero.
+ if guidance_scale == 1:
+ # If guidance_scale is 1, adjusting is not needed or possible in the same way,
+ # since it would lead to division by zero. This also means the algorithm inherently
+ # doesn't alter the noise_pred_cond in relation to noise_pred_uncond.
+ # Thus, we return the original values, though this situation might need special handling.
+ return noise_pred_cond
+ adjusted_noise_pred_cond = (noise_pred_cond - noise_pred_uncond) / guidance_scale
+ return adjusted_noise_pred_cond
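+
+
+# Usage sketch (tensor names are placeholders): given the conditional and
+# unconditional predictions from a CFG forward pass, recover an adjusted
+# conditional term before re-applying guidance.
+#
+#     adjusted_cond = inverse_classifier_guidance(noise_pred_cond, noise_pred_uncond, guidance_scale)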