Spaces: Runtime error
Commit · 45732de
Duplicate from multimodalart/LoraTheExplorer
Co-authored-by: Apolinário from multimodal AI art <[email protected]>
- .gitattributes +35 -0
- .gitignore +1 -0
- README.md +14 -0
- app.py +289 -0
- cog_sdxl_dataset_and_utils.py +422 -0
- custom.css +27 -0
- images/3d_style_4.jpeg +0 -0
- images/LineAni.Redmond.png +0 -0
- images/LogoRedmond-LogoLoraForSDXL.jpeg +0 -0
- images/ToyRedmond-ToyLoraForSDXL10.png +0 -0
- images/corgi_brick.jpeg +0 -0
- images/crayon.png +0 -0
- images/dog.png +0 -0
- images/embroid.png +0 -0
- images/jojoso1.jpg +0 -0
- images/josef_koudelka.webp +0 -0
- images/lego-minifig-xl.jpeg +0 -0
- images/papercut_SDXL.jpeg +0 -0
- images/pikachu.webp +0 -0
- images/pixel-art-xl.jpeg +0 -0
- images/riding-min.jpg +0 -0
- images/the_fish.jpg +0 -0
- images/uglysonic.webp +0 -0
- images/voxel-xl-lora.png +0 -0
- images/watercolor.png +0 -0
- images/william_eggleston.webp +0 -0
- lora.png +0 -0
- lora.py +1222 -0
- requirements.txt +4 -0
- sdxl_loras.json +279 -0
- share_btn.py +76 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+.vscode
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: LoRA the Explorer
+emoji: 🔎 🖼️
+colorFrom: indigo
+colorTo: blue
+sdk: gradio
+sdk_version: 3.39.0
+app_file: app.py
+pinned: false
+license: mit
+duplicated_from: multimodalart/LoraTheExplorer
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,289 @@
+import gradio as gr
+import torch
+from diffusers import StableDiffusionXLPipeline, AutoencoderKL
+from huggingface_hub import hf_hub_download
+from share_btn import community_icon_html, loading_icon_html, share_js
+from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
+import lora
+from time import sleep
+import copy
+import json
+import gc
+
+with open("sdxl_loras.json", "r") as file:
+    data = json.load(file)
+    sdxl_loras = [
+        {
+            "image": item["image"],
+            "title": item["title"],
+            "repo": item["repo"],
+            "trigger_word": item["trigger_word"],
+            "weights": item["weights"],
+            "is_compatible": item["is_compatible"],
+            "is_pivotal": item.get("is_pivotal", False),
+            "text_embedding_weights": item.get("text_embedding_weights", None),
+            "is_nc": item.get("is_nc", False)
+        }
+        for item in data
+    ]
+print(sdxl_loras)
+saved_names = [
+    hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
+]
+
+device = "cuda" # replace this to `mps` if on a MacOS Silicon
+
+vae = AutoencoderKL.from_pretrained(
+    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    vae=vae,
+    torch_dtype=torch.float16,
+).to("cpu")
+original_pipe = copy.deepcopy(pipe)
+pipe.to(device)
+
+last_lora = ""
+last_merged = False
+
+
+def update_selection(selected_state: gr.SelectData):
+    lora_repo = sdxl_loras[selected_state.index]["repo"]
+    instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
+    new_placeholder = "Type a prompt. This LoRA applies for all prompts, no need for a trigger word" if instance_prompt == "" else "Type a prompt to use your selected LoRA"
+    weight_name = sdxl_loras[selected_state.index]["weights"]
+    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
+    is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
+    is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
+
+    use_with_diffusers = f'''
+## Using [`{lora_repo}`](https://huggingface.co/{lora_repo})
+
+## Use it with diffusers:
+'''
+    if is_compatible:
+        use_with_diffusers += f'''
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
+pipe.to("cuda")
+pipe.load_lora_weights("{lora_repo}", weight_name="{weight_name}")
+
+prompt = "{instance_prompt}..."
+lora_scale= 0.9
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={{"scale": lora_scale}}).images[0]
+image.save("image.png")
+'''
+    elif not is_pivotal:
+        use_with_diffusers += "This LoRA is not compatible with diffusers natively yet. But you can still use it on diffusers with `bmaltais/kohya_ss` LoRA class, check out this [Google Colab](https://colab.research.google.com/drive/14aEJsKdEQ9_kyfsiV6JDok799kxPul0j )"
+    else:
+        use_with_diffusers += f"This LoRA is not compatible with diffusers natively yet. But you can still use it on diffusers with sdxl-cog `TokenEmbeddingsHandler` class, check out the [model repo](https://huggingface.co/{lora_repo}#inference-with-🧨-diffusers)"
+    use_with_uis = f'''
+## Use it with Comfy UI, Invoke AI, SD.Next, AUTO1111:
+
+### Download the `*.safetensors` weights of [here](https://huggingface.co/{lora_repo}/resolve/main/{weight_name})
+
+- [ComfyUI guide](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- [Invoke AI guide](https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/?h=lora#using-loras)
+- [SD.Next guide](https://github.com/vladmandic/automatic)
+- [AUTOMATIC1111 guide](https://stable-diffusion-art.com/lora/)
+'''
+    return (
+        updated_text,
+        instance_prompt,
+        gr.update(placeholder=new_placeholder),
+        selected_state,
+        use_with_diffusers,
+        use_with_uis,
+    )
+
+
+def check_selected(selected_state):
+    if not selected_state:
+        raise gr.Error("You must select a LoRA")
+
+def merge_incompatible_lora(full_path_lora, lora_scale):
+    for weights_file in [full_path_lora]:
+        if ";" in weights_file:
+            weights_file, multiplier = weights_file.split(";")
+            multiplier = float(multiplier)
+        else:
+            multiplier = lora_scale
+
+        lora_model, weights_sd = lora.create_network_from_weights(
+            multiplier,
+            full_path_lora,
+            pipe.vae,
+            pipe.text_encoder,
+            pipe.unet,
+            for_inference=True,
+        )
+        lora_model.merge_to(
+            pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
+        )
+        del weights_sd
+        del lora_model
+        gc.collect()
+
+def run_lora(prompt, negative, lora_scale, selected_state):
+    global last_lora, last_merged, pipe
+
+    if negative == "":
+        negative = None
+
+    if not selected_state:
+        raise gr.Error("You must select a LoRA")
+    repo_name = sdxl_loras[selected_state.index]["repo"]
+    weight_name = sdxl_loras[selected_state.index]["weights"]
+    full_path_lora = saved_names[selected_state.index]
+    cross_attention_kwargs = None
+    if last_lora != repo_name:
+        if last_merged:
+            del pipe
+            gc.collect()
+            pipe = copy.deepcopy(original_pipe)
+            pipe.to(device)
+        else:
+            pipe.unload_lora_weights()
+        is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
+        if is_compatible:
+            pipe.load_lora_weights(full_path_lora)
+            cross_attention_kwargs = {"scale": lora_scale}
+        else:
+            is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
+            if(is_pivotal):
+
+                pipe.load_lora_weights(full_path_lora)
+                cross_attention_kwargs = {"scale": lora_scale}
+
+                #Add the textual inversion embeddings from pivotal tuning models
+                text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
+                text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+                tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+                embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+                embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
+                embhandler.load_embeddings(embedding_path)
+            else:
+                merge_incompatible_lora(full_path_lora, lora_scale)
+                last_merged = True
+
+    image = pipe(
+        prompt=prompt,
+        negative_prompt=negative,
+        width=768,
+        height=768,
+        num_inference_steps=20,
+        guidance_scale=7.5,
+        cross_attention_kwargs=cross_attention_kwargs,
+    ).images[0]
+    last_lora = repo_name
+    gc.collect()
+    return image, gr.update(visible=True)
+
+
+with gr.Blocks(css="custom.css") as demo:
+    title = gr.HTML(
+        """<h1><img src="https://i.imgur.com/vT48NAO.png" alt="LoRA"> LoRA the Explorer</h1>""",
+        elem_id="title",
+    )
+    selected_state = gr.State()
+    with gr.Row():
+        gallery = gr.Gallery(
+            value=[(item["image"], item["title"]) for item in sdxl_loras],
+            label="SDXL LoRA Gallery",
+            allow_preview=False,
+            columns=3,
+            elem_id="gallery",
+            show_share_button=False
+        )
+        with gr.Column():
+            prompt_title = gr.Markdown(
+                value="### Click on a LoRA in the gallery to select it",
+                visible=True,
+                elem_id="selected_lora",
+            )
+            with gr.Row():
+                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, placeholder="Type a prompt after selecting a LoRA", elem_id="prompt")
+                button = gr.Button("Run", elem_id="run_button")
+            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
+                community_icon = gr.HTML(community_icon_html)
+                loading_icon = gr.HTML(loading_icon_html)
+                share_button = gr.Button("Share to community", elem_id="share-btn")
+            result = gr.Image(
+                interactive=False, label="Generated Image", elem_id="result-image"
+            )
+            with gr.Accordion("Advanced options", open=False):
+                negative = gr.Textbox(label="Negative Prompt")
+                weight = gr.Slider(0, 10, value=1, step=0.1, label="LoRA weight")
+
+    with gr.Column(elem_id="extra_info"):
+        with gr.Accordion(
+            "Use it with: 🧨 diffusers, ComfyUI, Invoke AI, SD.Next, AUTO1111",
+            open=False,
+            elem_id="accordion",
+        ):
+            with gr.Row():
+                use_diffusers = gr.Markdown("""## Select a LoRA first 🤗""")
+                use_uis = gr.Markdown()
+        with gr.Accordion("Submit a LoRA! 📥", open=False):
+            submit_title = gr.Markdown(
+                "### Streamlined submission coming soon! Until then [suggest your LoRA in the community tab](https://huggingface.co/spaces/multimodalart/LoraTheExplorer/discussions) 🤗"
+            )
+            with gr.Box(elem_id="soon"):
+                submit_source = gr.Radio(
+                    ["Hugging Face", "CivitAI"],
+                    label="LoRA source",
+                    value="Hugging Face",
+                )
+                with gr.Row():
+                    submit_source_hf = gr.Textbox(
+                        label="Hugging Face Model Repo",
+                        info="In the format `username/model_id`",
+                    )
+                    submit_safetensors_hf = gr.Textbox(
+                        label="Safetensors filename",
+                        info="The filename `*.safetensors` in the model repo",
+                    )
+                with gr.Row():
+                    submit_trigger_word_hf = gr.Textbox(label="Trigger word")
+                    submit_image = gr.Image(
+                        label="Example image (optional if the repo already contains images)"
+                    )
+                submit_button = gr.Button("Submit!")
+                submit_disclaimer = gr.Markdown(
+                    "This is a curated gallery by me, [apolinário (multimodal.art)](https://twitter.com/multimodalart). I'll try to include as many cool LoRAs as they are submitted! You can [duplicate this Space](https://huggingface.co/spaces/multimodalart/LoraTheExplorer?duplicate=true) to use it privately, and add your own LoRAs by editing `sdxl_loras.json` in the Files tab of your private space."
+                )
+
+    gallery.select(
+        update_selection,
+        outputs=[prompt_title, prompt, prompt, selected_state, use_diffusers, use_uis],
+        queue=False,
+        show_progress=False,
+    )
+    prompt.submit(
+        fn=check_selected,
+        inputs=[selected_state],
+        queue=False,
+        show_progress=False
+    ).success(
+        fn=run_lora,
+        inputs=[prompt, negative, weight, selected_state],
+        outputs=[result, share_group],
+    )
+    button.click(
+        fn=check_selected,
+        inputs=[selected_state],
+        queue=False,
+        show_progress=False
+    ).success(
+        fn=run_lora,
+        inputs=[prompt, negative, weight, selected_state],
+        outputs=[result, share_group],
+    )
+    share_button.click(None, [], [], _js=share_js)
+
+demo.queue(max_size=20)
+demo.launch()
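For orientation, a minimal sketch of the inference path run_lora() above takes for a diffusers-compatible LoRA ("is_compatible": true). The LoRA repo id, weight filename and prompt are illustrative placeholders, not values from this commit.

import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

# Same base setup as app.py: fp16 SDXL base with the fp16-fixed VAE.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
).to("cuda")

# "is_compatible" branch of run_lora(): load the LoRA, then pass its strength via cross_attention_kwargs.
pipe.load_lora_weights("username/some-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")  # placeholder repo/file
image = pipe(
    "a photo in the style of the selected LoRA",  # prepend the LoRA's trigger word if it has one
    num_inference_steps=20,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.9},  # the "LoRA weight" slider in the UI
).images[0]
image.save("image.png")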
cog_sdxl_dataset_and_utils.py
ADDED
@@ -0,0 +1,422 @@
+# dataset_and_utils.py file taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+import os
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import pandas as pd
+import PIL
+import torch
+import torch.utils.checkpoint
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from PIL import Image
+from safetensors import safe_open
+from safetensors.torch import save_file
+from torch.utils.data import Dataset
+from transformers import AutoTokenizer, PretrainedConfig
+
+
+def prepare_image(
+    pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+) -> torch.Tensor:
+    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+    arr = np.array(pil_image.convert("RGB"))
+    arr = arr.astype(np.float32) / 127.5 - 1
+    arr = np.transpose(arr, [2, 0, 1])
+    image = torch.from_numpy(arr).unsqueeze(0)
+    return image
+
+
+def prepare_mask(
+    pil_image: PIL.Image.Image, w: int = 512, h: int = 512
+) -> torch.Tensor:
+    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
+    arr = np.array(pil_image.convert("L"))
+    arr = arr.astype(np.float32) / 255.0
+    arr = np.expand_dims(arr, 0)
+    image = torch.from_numpy(arr).unsqueeze(0)
+    return image
+
+
+class PreprocessedDataset(Dataset):
+    def __init__(
+        self,
+        csv_path: str,
+        tokenizer_1,
+        tokenizer_2,
+        vae_encoder,
+        text_encoder_1=None,
+        text_encoder_2=None,
+        do_cache: bool = False,
+        size: int = 512,
+        text_dropout: float = 0.0,
+        scale_vae_latents: bool = True,
+        substitute_caption_map: Dict[str, str] = {},
+    ):
+        super().__init__()
+
+        self.data = pd.read_csv(csv_path)
+        self.csv_path = csv_path
+
+        self.caption = self.data["caption"]
+        # make it lowercase
+        self.caption = self.caption.str.lower()
+        for key, value in substitute_caption_map.items():
+            self.caption = self.caption.str.replace(key.lower(), value)
+
+        self.image_path = self.data["image_path"]
+
+        if "mask_path" not in self.data.columns:
+            self.mask_path = None
+        else:
+            self.mask_path = self.data["mask_path"]
+
+        if text_encoder_1 is None:
+            self.return_text_embeddings = False
+        else:
+            self.text_encoder_1 = text_encoder_1
+            self.text_encoder_2 = text_encoder_2
+            self.return_text_embeddings = True
+            assert (
+                NotImplementedError
+            ), "Preprocessing Text Encoder is not implemented yet"
+
+        self.tokenizer_1 = tokenizer_1
+        self.tokenizer_2 = tokenizer_2
+
+        self.vae_encoder = vae_encoder
+        self.scale_vae_latents = scale_vae_latents
+        self.text_dropout = text_dropout
+
+        self.size = size
+
+        if do_cache:
+            self.vae_latents = []
+            self.tokens_tuple = []
+            self.masks = []
+
+            self.do_cache = True
+
+            print("Captions to train on: ")
+            for idx in range(len(self.data)):
+                token, vae_latent, mask = self._process(idx)
+                self.vae_latents.append(vae_latent)
+                self.tokens_tuple.append(token)
+                self.masks.append(mask)
+
+            del self.vae_encoder
+
+        else:
+            self.do_cache = False
+
+    @torch.no_grad()
+    def _process(
+        self, idx: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        image_path = self.image_path[idx]
+        image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
+
+        image = PIL.Image.open(image_path).convert("RGB")
+        image = prepare_image(image, self.size, self.size).to(
+            dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+        )
+
+        caption = self.caption[idx]
+
+        print(caption)
+
+        # tokenizer_1
+        ti1 = self.tokenizer_1(
+            caption,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        ).input_ids
+
+        ti2 = self.tokenizer_2(
+            caption,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        ).input_ids
+
+        vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
+
+        if self.scale_vae_latents:
+            vae_latent = vae_latent * self.vae_encoder.config.scaling_factor
+
+        if self.mask_path is None:
+            mask = torch.ones_like(
+                vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+            )
+
+        else:
+            mask_path = self.mask_path[idx]
+            mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
+
+            mask = PIL.Image.open(mask_path)
+            mask = prepare_mask(mask, self.size, self.size).to(
+                dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
+            )
+
+            mask = torch.nn.functional.interpolate(
+                mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
+            )
+            mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
+
+        assert len(mask.shape) == 4 and len(vae_latent.shape) == 4
+
+        return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def atidx(
+        self, idx: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        if self.do_cache:
+            return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
+        else:
+            return self._process(idx)
+
+    def __getitem__(
+        self, idx: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        token, vae_latent, mask = self.atidx(idx)
+        return token, vae_latent, mask
+
+
+def import_model_class_from_model_name_or_path(
+    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+    text_encoder_config = PretrainedConfig.from_pretrained(
+        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+    )
+    model_class = text_encoder_config.architectures[0]
+
+    if model_class == "CLIPTextModel":
+        from transformers import CLIPTextModel
+
+        return CLIPTextModel
+    elif model_class == "CLIPTextModelWithProjection":
+        from transformers import CLIPTextModelWithProjection
+
+        return CLIPTextModelWithProjection
+    else:
+        raise ValueError(f"{model_class} is not supported.")
+
+
+def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
+    tokenizer_one = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=revision,
+        use_fast=False,
+    )
+    tokenizer_two = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=revision,
+        use_fast=False,
+    )
+
+    # Load scheduler and models
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        pretrained_model_name_or_path, revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
+    )
+    text_encoder_one = text_encoder_cls_one.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
+    )
+    text_encoder_two = text_encoder_cls_two.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
+    )
+
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", revision=revision
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", revision=revision
+    )
+
+    vae.requires_grad_(False)
+    text_encoder_one.requires_grad_(False)
+    text_encoder_two.requires_grad_(False)
+
+    unet.to(device, dtype=weight_dtype)
+    vae.to(device, dtype=torch.float32)
+    text_encoder_one.to(device, dtype=weight_dtype)
+    text_encoder_two.to(device, dtype=weight_dtype)
+
+    return (
+        tokenizer_one,
+        tokenizer_two,
+        noise_scheduler,
+        text_encoder_one,
+        text_encoder_two,
+        vae,
+        unet,
+    )
+
+
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
+    """
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[
+                f"{attn_processor_key}.{parameter_key}"
+            ] = parameter
+
+    return attn_processors_state_dict
+
+
+class TokenEmbeddingsHandler:
+    def __init__(self, text_encoders, tokenizers):
+        self.text_encoders = text_encoders
+        self.tokenizers = tokenizers
+
+        self.train_ids: Optional[torch.Tensor] = None
+        self.inserting_toks: Optional[List[str]] = None
+        self.embeddings_settings = {}
+
+    def initialize_new_tokens(self, inserting_toks: List[str]):
+        idx = 0
+        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+            assert isinstance(
+                inserting_toks, list
+            ), "inserting_toks should be a list of strings."
+            assert all(
+                isinstance(tok, str) for tok in inserting_toks
+            ), "All elements in inserting_toks should be strings."
+
+            self.inserting_toks = inserting_toks
+            special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+            tokenizer.add_special_tokens(special_tokens_dict)
+            text_encoder.resize_token_embeddings(len(tokenizer))
+
+            self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+            # random initialization of new tokens
+
+            std_token_embedding = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data.std()
+            )
+
+            print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
+
+            text_encoder.text_model.embeddings.token_embedding.weight.data[
+                self.train_ids
+            ] = (
+                torch.randn(
+                    len(self.train_ids), text_encoder.text_model.config.hidden_size
+                )
+                .to(device=self.device)
+                .to(dtype=self.dtype)
+                * std_token_embedding
+            )
+            self.embeddings_settings[
+                f"original_embeddings_{idx}"
+            ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
+            self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+            inu = torch.ones((len(tokenizer),), dtype=torch.bool)
+            inu[self.train_ids] = False
+
+            self.embeddings_settings[f"index_no_updates_{idx}"] = inu
+
+            print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+            idx += 1
+
+    def save_embeddings(self, file_path: str):
+        assert (
+            self.train_ids is not None
+        ), "Initialize new tokens before saving embeddings."
+        tensors = {}
+        for idx, text_encoder in enumerate(self.text_encoders):
+            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
+                0
+            ] == len(self.tokenizers[0]), "Tokenizers should be the same."
+            new_token_embeddings = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data[
+                    self.train_ids
+                ]
+            )
+            tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+        save_file(tensors, file_path)
+
+    @property
+    def dtype(self):
+        return self.text_encoders[0].dtype
+
+    @property
+    def device(self):
+        return self.text_encoders[0].device
+
+    def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
+        # Assuming new tokens are of the format <s_i>
+        self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
+        special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+        tokenizer.add_special_tokens(special_tokens_dict)
+        text_encoder.resize_token_embeddings(len(tokenizer))
+
+        self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+        assert self.train_ids is not None, "New tokens could not be converted to IDs."
+        text_encoder.text_model.embeddings.token_embedding.weight.data[
+            self.train_ids
+        ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
+
+    @torch.no_grad()
+    def retract_embeddings(self):
+        for idx, text_encoder in enumerate(self.text_encoders):
+            index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+            text_encoder.text_model.embeddings.token_embedding.weight.data[
+                index_no_updates
+            ] = (
+                self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+                .to(device=text_encoder.device)
+                .to(dtype=text_encoder.dtype)
+            )
+
+            # for the parts that were updated, we need to normalize them
+            # to have the same std as before
+            std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+            index_updates = ~index_no_updates
+            new_embeddings = (
+                text_encoder.text_model.embeddings.token_embedding.weight.data[
+                    index_updates
+                ]
+            )
+            off_ratio = std_token_embedding / new_embeddings.std()
+
+            new_embeddings = new_embeddings * (off_ratio**0.1)
+            text_encoder.text_model.embeddings.token_embedding.weight.data[
+                index_updates
+            ] = new_embeddings
+
+    def load_embeddings(self, file_path: str):
+        with safe_open(file_path, framework="pt", device=self.device.type) as f:
+            for idx in range(len(self.text_encoders)):
+                text_encoder = self.text_encoders[idx]
+                tokenizer = self.tokenizers[idx]
+
+                loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
+                self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
+
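A minimal sketch of how app.py uses TokenEmbeddingsHandler for pivotal-tuning LoRAs (the "is_pivotal" branch of run_lora), assuming the `pipe` built above; the repo id and filenames are illustrative placeholders, not values from this commit.

from huggingface_hub import hf_hub_download
from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler

lora_repo = "username/some-pivotal-sdxl-lora"  # placeholder repo id
pipe.load_lora_weights(lora_repo, weight_name="lora.safetensors")  # placeholder filename

# Inject the trained <s0>, <s1>, ... token embeddings into both SDXL tokenizers/text encoders,
# as the is_pivotal branch of run_lora() does.
embedding_path = hf_hub_download(repo_id=lora_repo, filename="embeddings.pti", repo_type="model")  # placeholder filename
handler = TokenEmbeddingsHandler(
    [pipe.text_encoder, pipe.text_encoder_2], [pipe.tokenizer, pipe.tokenizer_2]
)
handler.load_embeddings(embedding_path)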
custom.css
ADDED
@@ -0,0 +1,27 @@
+#title{text-align: center;}
+#title h1{font-size: 3em; display:inline-flex; align-items:center}
+#title img{width: 100px; margin-right: 0.5em}
+#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
+#run_button{position:absolute;margin-top: 11px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
+border-top-left-radius: 0px;}
+#gallery{display:flex}
+#gallery .grid-wrap{min-height: 100%;}
+#accordion code{word-break: break-all;word-wrap: break-word;white-space: pre-wrap}
+#soon{opacity: 0.55; pointer-events: none}
+#soon button{width: 100%}
+#share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;}
+div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
+#share-btn-container:hover {background-color: #060606}
+#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;}
+#share-btn * {all: unset}
+#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
+#share-btn-container .wrap {display: none !important}
+#share-btn-container.hidden {display: none!important}
+#extra_info{margin-top: 1em}
+.pending .min {min-height: auto}
+
+@media (max-width: 527px) {
+#title h1{font-size: 2.2em}
+#title img{width: 80px;}
+#gallery {max-height: 370px}
+}
images/3d_style_4.jpeg
ADDED
images/LineAni.Redmond.png
ADDED
images/LogoRedmond-LogoLoraForSDXL.jpeg
ADDED
images/ToyRedmond-ToyLoraForSDXL10.png
ADDED
images/corgi_brick.jpeg
ADDED
images/crayon.png
ADDED
images/dog.png
ADDED
images/embroid.png
ADDED
images/jojoso1.jpg
ADDED
images/josef_koudelka.webp
ADDED
images/lego-minifig-xl.jpeg
ADDED
images/papercut_SDXL.jpeg
ADDED
images/pikachu.webp
ADDED
images/pixel-art-xl.jpeg
ADDED
images/riding-min.jpg
ADDED
images/the_fish.jpg
ADDED
images/uglysonic.webp
ADDED
images/voxel-xl-lora.png
ADDED
images/watercolor.png
ADDED
images/william_eggleston.webp
ADDED
lora.png
ADDED
lora.py
ADDED
|
@@ -0,0 +1,1222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LoRA network module taken from https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py
|
| 2 |
+
# reference:
|
| 3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
| 4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 9 |
+
from diffusers import AutoencoderKL
|
| 10 |
+
from transformers import CLIPTextModel
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import re
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
| 17 |
+
|
| 18 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LoRAModule(torch.nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
lora_name,
|
| 29 |
+
org_module: torch.nn.Module,
|
| 30 |
+
multiplier=1.0,
|
| 31 |
+
lora_dim=4,
|
| 32 |
+
alpha=1,
|
| 33 |
+
dropout=None,
|
| 34 |
+
rank_dropout=None,
|
| 35 |
+
module_dropout=None,
|
| 36 |
+
):
|
| 37 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.lora_name = lora_name
|
| 40 |
+
|
| 41 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 42 |
+
in_dim = org_module.in_channels
|
| 43 |
+
out_dim = org_module.out_channels
|
| 44 |
+
else:
|
| 45 |
+
in_dim = org_module.in_features
|
| 46 |
+
out_dim = org_module.out_features
|
| 47 |
+
|
| 48 |
+
# if limit_rank:
|
| 49 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
| 50 |
+
# if self.lora_dim != lora_dim:
|
| 51 |
+
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
| 52 |
+
# else:
|
| 53 |
+
self.lora_dim = lora_dim
|
| 54 |
+
|
| 55 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 56 |
+
kernel_size = org_module.kernel_size
|
| 57 |
+
stride = org_module.stride
|
| 58 |
+
padding = org_module.padding
|
| 59 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
| 60 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
| 61 |
+
else:
|
| 62 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
| 63 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
| 64 |
+
|
| 65 |
+
if type(alpha) == torch.Tensor:
|
| 66 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 67 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 68 |
+
self.scale = alpha / self.lora_dim
|
| 69 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
| 70 |
+
|
| 71 |
+
# same as microsoft's
|
| 72 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 73 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
| 74 |
+
|
| 75 |
+
self.multiplier = multiplier
|
| 76 |
+
self.org_module = org_module # remove in applying
|
| 77 |
+
self.dropout = dropout
|
| 78 |
+
self.rank_dropout = rank_dropout
|
| 79 |
+
self.module_dropout = module_dropout
|
| 80 |
+
|
| 81 |
+
def apply_to(self):
|
| 82 |
+
self.org_forward = self.org_module.forward
|
| 83 |
+
self.org_module.forward = self.forward
|
| 84 |
+
del self.org_module
|
| 85 |
+
|
| 86 |
+
def forward(self, x):
|
| 87 |
+
org_forwarded = self.org_forward(x)
|
| 88 |
+
|
| 89 |
+
# module dropout
|
| 90 |
+
if self.module_dropout is not None and self.training:
|
| 91 |
+
if torch.rand(1) < self.module_dropout:
|
| 92 |
+
return org_forwarded
|
| 93 |
+
|
| 94 |
+
lx = self.lora_down(x)
|
| 95 |
+
|
| 96 |
+
# normal dropout
|
| 97 |
+
if self.dropout is not None and self.training:
|
| 98 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
| 99 |
+
|
| 100 |
+
# rank dropout
|
| 101 |
+
if self.rank_dropout is not None and self.training:
|
| 102 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
| 103 |
+
if len(lx.size()) == 3:
|
| 104 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
| 105 |
+
elif len(lx.size()) == 4:
|
| 106 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
| 107 |
+
lx = lx * mask
|
| 108 |
+
|
| 109 |
+
# scaling for rank dropout: treat as if the rank is changed
|
| 110 |
+
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
| 111 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
| 112 |
+
else:
|
| 113 |
+
scale = self.scale
|
| 114 |
+
|
| 115 |
+
lx = self.lora_up(lx)
|
| 116 |
+
|
| 117 |
+
return org_forwarded + lx * self.multiplier * scale
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class LoRAInfModule(LoRAModule):
|
| 121 |
+
def __init__(
|
| 122 |
+
self,
|
| 123 |
+
lora_name,
|
| 124 |
+
org_module: torch.nn.Module,
|
| 125 |
+
multiplier=1.0,
|
| 126 |
+
lora_dim=4,
|
| 127 |
+
alpha=1,
|
| 128 |
+
**kwargs,
|
| 129 |
+
):
|
| 130 |
+
# no dropout for inference
|
| 131 |
+
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
| 132 |
+
|
| 133 |
+
self.org_module_ref = [org_module] # 後から参照できるように
|
| 134 |
+
self.enabled = True
|
| 135 |
+
|
| 136 |
+
# check regional or not by lora_name
|
| 137 |
+
self.text_encoder = False
|
| 138 |
+
if lora_name.startswith("lora_te_"):
|
| 139 |
+
self.regional = False
|
| 140 |
+
self.use_sub_prompt = True
|
| 141 |
+
self.text_encoder = True
|
| 142 |
+
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
| 143 |
+
self.regional = False
|
| 144 |
+
self.use_sub_prompt = True
|
| 145 |
+
elif "time_emb" in lora_name:
|
| 146 |
+
self.regional = False
|
| 147 |
+
self.use_sub_prompt = False
|
| 148 |
+
else:
|
| 149 |
+
self.regional = True
|
| 150 |
+
self.use_sub_prompt = False
|
| 151 |
+
|
| 152 |
+
self.network: LoRANetwork = None
|
| 153 |
+
|
| 154 |
+
def set_network(self, network):
|
| 155 |
+
self.network = network
|
| 156 |
+
|
| 157 |
+
# freezeしてマージする
|
| 158 |
+
def merge_to(self, sd, dtype, device):
|
| 159 |
+
# get up/down weight
|
| 160 |
+
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
| 161 |
+
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
| 162 |
+
|
| 163 |
+
# extract weight from org_module
|
| 164 |
+
org_sd = self.org_module.state_dict()
|
| 165 |
+
weight = org_sd["weight"].to(torch.float)
|
| 166 |
+
|
| 167 |
+
# merge weight
|
| 168 |
+
if len(weight.size()) == 2:
|
| 169 |
+
# linear
|
| 170 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
| 171 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 172 |
+
# conv2d 1x1
|
| 173 |
+
weight = (
|
| 174 |
+
weight
|
| 175 |
+
+ self.multiplier
|
| 176 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 177 |
+
* self.scale
|
| 178 |
+
)
|
| 179 |
+
else:
|
| 180 |
+
# conv2d 3x3
|
| 181 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 182 |
+
# print(conved.size(), weight.size(), module.stride, module.padding)
|
| 183 |
+
weight = weight + self.multiplier * conved * self.scale
|
| 184 |
+
|
| 185 |
+
# set weight to org_module
|
| 186 |
+
org_sd["weight"] = weight.to(dtype)
|
| 187 |
+
self.org_module.load_state_dict(org_sd)
|
| 188 |
+
|
| 189 |
+
# 復元できるマージのため、このモジュールのweightを返す
|
| 190 |
+
def get_weight(self, multiplier=None):
|
| 191 |
+
if multiplier is None:
|
| 192 |
+
multiplier = self.multiplier
|
| 193 |
+
|
| 194 |
+
# get up/down weight from module
|
| 195 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
| 196 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
| 197 |
+
|
| 198 |
+
# pre-calculated weight
|
| 199 |
+
if len(down_weight.size()) == 2:
|
| 200 |
+
# linear
|
| 201 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
| 202 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 203 |
+
# conv2d 1x1
|
| 204 |
+
weight = (
|
| 205 |
+
self.multiplier
|
| 206 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 207 |
+
* self.scale
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
# conv2d 3x3
|
| 211 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 212 |
+
weight = self.multiplier * conved * self.scale
|
| 213 |
+
|
| 214 |
+
return weight
|
| 215 |
+
|
| 216 |
+
def set_region(self, region):
|
| 217 |
+
self.region = region
|
| 218 |
+
self.region_mask = None
|
| 219 |
+
|
| 220 |
+
def default_forward(self, x):
|
| 221 |
+
# print("default_forward", self.lora_name, x.size())
|
| 222 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 223 |
+
|
| 224 |
+
def forward(self, x):
|
| 225 |
+
if not self.enabled:
|
| 226 |
+
return self.org_forward(x)
|
| 227 |
+
|
| 228 |
+
if self.network is None or self.network.sub_prompt_index is None:
|
| 229 |
+
return self.default_forward(x)
|
| 230 |
+
if not self.regional and not self.use_sub_prompt:
|
| 231 |
+
return self.default_forward(x)
|
| 232 |
+
|
| 233 |
+
if self.regional:
|
| 234 |
+
return self.regional_forward(x)
|
| 235 |
+
else:
|
| 236 |
+
return self.sub_prompt_forward(x)
|
| 237 |
+
|
| 238 |
+
def get_mask_for_x(self, x):
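# look up the region mask whose element count matches x (spatial h*w for 4D inputs, sequence length otherwise)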
|
| 239 |
+
# calculate size from shape of x
|
| 240 |
+
if len(x.size()) == 4:
|
| 241 |
+
h, w = x.size()[2:4]
|
| 242 |
+
area = h * w
|
| 243 |
+
else:
|
| 244 |
+
area = x.size()[1]
|
| 245 |
+
|
| 246 |
+
mask = self.network.mask_dic[area]
|
| 247 |
+
if mask is None:
|
| 248 |
+
raise ValueError(f"mask is None for resolution {area}")
|
| 249 |
+
if len(x.size()) != 4:
|
| 250 |
+
mask = torch.reshape(mask, (1, -1, 1))
|
| 251 |
+
return mask
|
| 252 |
+
|
| 253 |
+
def regional_forward(self, x):
|
| 254 |
+
if "attn2_to_out" in self.lora_name:
|
| 255 |
+
return self.to_out_forward(x)
|
| 256 |
+
|
| 257 |
+
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
| 258 |
+
return self.default_forward(x)
|
| 259 |
+
|
| 260 |
+
# apply mask for LoRA result
|
| 261 |
+
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 262 |
+
mask = self.get_mask_for_x(lx)
|
| 263 |
+
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
| 264 |
+
lx = lx * mask
|
| 265 |
+
|
| 266 |
+
x = self.org_forward(x)
|
| 267 |
+
x = x + lx
|
| 268 |
+
|
| 269 |
+
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
| 270 |
+
x = self.postp_to_q(x)
|
| 271 |
+
|
| 272 |
+
return x
|
| 273 |
+
|
| 274 |
+
def postp_to_q(self, x):
|
| 275 |
+
# repeat x to num_sub_prompts
|
| 276 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
| 277 |
+
qc = self.network.batch_size # uncond
|
| 278 |
+
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
| 279 |
+
if has_real_uncond:
|
| 280 |
+
qc += self.network.batch_size # real_uncond
|
| 281 |
+
|
| 282 |
+
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
| 283 |
+
query[: self.network.batch_size] = x[: self.network.batch_size]
|
| 284 |
+
|
| 285 |
+
for i in range(self.network.batch_size):
|
| 286 |
+
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
| 287 |
+
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
| 288 |
+
|
| 289 |
+
if has_real_uncond:
|
| 290 |
+
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
| 291 |
+
|
| 292 |
+
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
| 293 |
+
return query
|
| 294 |
+
|
| 295 |
+
def sub_prompt_forward(self, x):
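# apply this network's LoRA only to the batch rows that belong to its own sub prompt; other rows pass through unchanged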
|
| 296 |
+
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
| 297 |
+
return self.org_forward(x)
|
| 298 |
+
|
| 299 |
+
emb_idx = self.network.sub_prompt_index
|
| 300 |
+
if not self.text_encoder:
|
| 301 |
+
emb_idx += self.network.batch_size
|
| 302 |
+
|
| 303 |
+
# apply sub prompt of X
|
| 304 |
+
lx = x[emb_idx :: self.network.num_sub_prompts]
|
| 305 |
+
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
| 306 |
+
|
| 307 |
+
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
| 308 |
+
|
| 309 |
+
x = self.org_forward(x)
|
| 310 |
+
x[emb_idx :: self.network.num_sub_prompts] += lx
|
| 311 |
+
|
| 312 |
+
return x
|
| 313 |
+
|
| 314 |
+
def to_out_forward(self, x):
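# attn2 to_out handling for regional LoRA: each network writes its sub prompt's LoRA output into a shared buffer, and the last network combines the per-sub-prompt results with a mask-weighted sum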
|
| 315 |
+
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
| 316 |
+
|
| 317 |
+
if self.network.is_last_network:
|
| 318 |
+
masks = [None] * self.network.num_sub_prompts
|
| 319 |
+
self.network.shared[self.lora_name] = (None, masks)
|
| 320 |
+
else:
|
| 321 |
+
lx, masks = self.network.shared[self.lora_name]
|
| 322 |
+
|
| 323 |
+
# call own LoRA
|
| 324 |
+
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
| 325 |
+
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
| 326 |
+
|
| 327 |
+
if self.network.is_last_network:
|
| 328 |
+
lx = torch.zeros(
|
| 329 |
+
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
| 330 |
+
)
|
| 331 |
+
self.network.shared[self.lora_name] = (lx, masks)
|
| 332 |
+
|
| 333 |
+
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
| 334 |
+
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
| 335 |
+
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
| 336 |
+
|
| 337 |
+
# if not last network, return x and masks
|
| 338 |
+
x = self.org_forward(x)
|
| 339 |
+
if not self.network.is_last_network:
|
| 340 |
+
return x
|
| 341 |
+
|
| 342 |
+
lx, masks = self.network.shared.pop(self.lora_name)
|
| 343 |
+
|
| 344 |
+
# if last network, combine separated x with mask weighted sum
|
| 345 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
| 346 |
+
|
| 347 |
+
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
| 348 |
+
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
| 349 |
+
if has_real_uncond:
|
| 350 |
+
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
| 351 |
+
|
| 352 |
+
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
| 353 |
+
# for i in range(len(masks)):
|
| 354 |
+
# if masks[i] is None:
|
| 355 |
+
# masks[i] = torch.zeros_like(masks[-1])
|
| 356 |
+
|
| 357 |
+
mask = torch.cat(masks)
|
| 358 |
+
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
| 359 |
+
for i in range(self.network.batch_size):
|
| 360 |
+
# process each image in the batch separately
|
| 361 |
+
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
| 362 |
+
lx1 = lx1 * mask
|
| 363 |
+
lx1 = torch.sum(lx1, dim=0)
|
| 364 |
+
|
| 365 |
+
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
| 366 |
+
x1 = x[xi : xi + self.network.num_sub_prompts]
|
| 367 |
+
x1 = x1 * mask
|
| 368 |
+
x1 = torch.sum(x1, dim=0)
|
| 369 |
+
x1 = x1 / mask_sum
|
| 370 |
+
|
| 371 |
+
x1 = x1 + lx1
|
| 372 |
+
out[self.network.batch_size + i] = x1
|
| 373 |
+
|
| 374 |
+
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
| 375 |
+
return out
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def parse_block_lr_kwargs(nw_kwargs):
|
| 379 |
+
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
| 380 |
+
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
| 381 |
+
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
| 382 |
+
|
| 383 |
+
# if none of the above is set, treat block-wise LR as disabled and return None
|
| 384 |
+
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
| 385 |
+
return None, None, None
|
| 386 |
+
|
| 387 |
+
# extract learning rate weight for each block
|
| 388 |
+
if down_lr_weight is not None:
|
| 389 |
+
# if some parameters are not set, use zero
|
| 390 |
+
if "," in down_lr_weight:
|
| 391 |
+
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
| 392 |
+
|
| 393 |
+
if mid_lr_weight is not None:
|
| 394 |
+
mid_lr_weight = float(mid_lr_weight)
|
| 395 |
+
|
| 396 |
+
if up_lr_weight is not None:
|
| 397 |
+
if "," in up_lr_weight:
|
| 398 |
+
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
| 399 |
+
|
| 400 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
| 401 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def create_network(
|
| 408 |
+
multiplier: float,
|
| 409 |
+
network_dim: Optional[int],
|
| 410 |
+
network_alpha: Optional[float],
|
| 411 |
+
vae: AutoencoderKL,
|
| 412 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
| 413 |
+
unet,
|
| 414 |
+
neuron_dropout: Optional[float] = None,
|
| 415 |
+
**kwargs,
|
| 416 |
+
):
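# build a LoRANetwork for training from kwargs: conv dims/alphas, per-block dims/alphas, dropouts and block-wise learning rate weights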
|
| 417 |
+
if network_dim is None:
|
| 418 |
+
network_dim = 4 # default
|
| 419 |
+
if network_alpha is None:
|
| 420 |
+
network_alpha = 1.0
|
| 421 |
+
|
| 422 |
+
# extract dim/alpha for conv2d, and block dim
|
| 423 |
+
conv_dim = kwargs.get("conv_dim", None)
|
| 424 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
| 425 |
+
if conv_dim is not None:
|
| 426 |
+
conv_dim = int(conv_dim)
|
| 427 |
+
if conv_alpha is None:
|
| 428 |
+
conv_alpha = 1.0
|
| 429 |
+
else:
|
| 430 |
+
conv_alpha = float(conv_alpha)
|
| 431 |
+
|
| 432 |
+
# block dim/alpha/lr
|
| 433 |
+
block_dims = kwargs.get("block_dims", None)
|
| 434 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
| 435 |
+
|
| 436 |
+
# if any of the above is specified, enable per-block dim (rank)
|
| 437 |
+
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
| 438 |
+
block_alphas = kwargs.get("block_alphas", None)
|
| 439 |
+
conv_block_dims = kwargs.get("conv_block_dims", None)
|
| 440 |
+
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
| 441 |
+
|
| 442 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
| 443 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
# remove block dim/alpha without learning rate
|
| 447 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
| 448 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
else:
|
| 452 |
+
block_alphas = None
|
| 453 |
+
conv_block_dims = None
|
| 454 |
+
conv_block_alphas = None
|
| 455 |
+
|
| 456 |
+
# rank/module dropout
|
| 457 |
+
rank_dropout = kwargs.get("rank_dropout", None)
|
| 458 |
+
if rank_dropout is not None:
|
| 459 |
+
rank_dropout = float(rank_dropout)
|
| 460 |
+
module_dropout = kwargs.get("module_dropout", None)
|
| 461 |
+
if module_dropout is not None:
|
| 462 |
+
module_dropout = float(module_dropout)
|
| 463 |
+
|
| 464 |
+
# quite a lot of arguments here...
|
| 465 |
+
network = LoRANetwork(
|
| 466 |
+
text_encoder,
|
| 467 |
+
unet,
|
| 468 |
+
multiplier=multiplier,
|
| 469 |
+
lora_dim=network_dim,
|
| 470 |
+
alpha=network_alpha,
|
| 471 |
+
dropout=neuron_dropout,
|
| 472 |
+
rank_dropout=rank_dropout,
|
| 473 |
+
module_dropout=module_dropout,
|
| 474 |
+
conv_lora_dim=conv_dim,
|
| 475 |
+
conv_alpha=conv_alpha,
|
| 476 |
+
block_dims=block_dims,
|
| 477 |
+
block_alphas=block_alphas,
|
| 478 |
+
conv_block_dims=conv_block_dims,
|
| 479 |
+
conv_block_alphas=conv_block_alphas,
|
| 480 |
+
varbose=True,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
| 484 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
| 485 |
+
|
| 486 |
+
return network
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# keep in mind that this function may be called from outside this module
|
| 490 |
+
# network_dim and network_alpha already hold their default values.
|
| 491 |
+
# block_dims and block_alphas are either both None or both set
|
| 492 |
+
# conv_dim and conv_alpha are either both None or both set
|
| 493 |
+
def get_block_dims_and_alphas(
|
| 494 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
| 495 |
+
):
|
| 496 |
+
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
|
| 497 |
+
|
| 498 |
+
def parse_ints(s):
|
| 499 |
+
return [int(i) for i in s.split(",")]
|
| 500 |
+
|
| 501 |
+
def parse_floats(s):
|
| 502 |
+
return [float(i) for i in s.split(",")]
|
| 503 |
+
|
| 504 |
+
# parse block_dims and block_alphas; both always end up with values
|
| 505 |
+
if block_dims is not None:
|
| 506 |
+
block_dims = parse_ints(block_dims)
|
| 507 |
+
assert (
|
| 508 |
+
len(block_dims) == num_total_blocks
|
| 509 |
+
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
| 510 |
+
else:
|
| 511 |
+
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
| 512 |
+
block_dims = [network_dim] * num_total_blocks
|
| 513 |
+
|
| 514 |
+
if block_alphas is not None:
|
| 515 |
+
block_alphas = parse_floats(block_alphas)
|
| 516 |
+
assert (
|
| 517 |
+
len(block_alphas) == num_total_blocks
|
| 518 |
+
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
| 519 |
+
else:
|
| 520 |
+
print(
|
| 521 |
+
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になり���す"
|
| 522 |
+
)
|
| 523 |
+
block_alphas = [network_alpha] * num_total_blocks
|
| 524 |
+
|
| 525 |
+
# parse conv_block_dims and conv_block_alphas only when specified; otherwise fall back to conv_dim and conv_alpha
|
| 526 |
+
if conv_block_dims is not None:
|
| 527 |
+
conv_block_dims = parse_ints(conv_block_dims)
|
| 528 |
+
assert (
|
| 529 |
+
len(conv_block_dims) == num_total_blocks
|
| 530 |
+
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
|
| 531 |
+
|
| 532 |
+
if conv_block_alphas is not None:
|
| 533 |
+
conv_block_alphas = parse_floats(conv_block_alphas)
|
| 534 |
+
assert (
|
| 535 |
+
len(conv_block_alphas) == num_total_blocks
|
| 536 |
+
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
|
| 537 |
+
else:
|
| 538 |
+
if conv_alpha is None:
|
| 539 |
+
conv_alpha = 1.0
|
| 540 |
+
print(
|
| 541 |
+
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
| 542 |
+
)
|
| 543 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
| 544 |
+
else:
|
| 545 |
+
if conv_dim is not None:
|
| 546 |
+
print(
|
| 547 |
+
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
| 548 |
+
)
|
| 549 |
+
conv_block_dims = [conv_dim] * num_total_blocks
|
| 550 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
| 551 |
+
else:
|
| 552 |
+
conv_block_dims = None
|
| 553 |
+
conv_block_alphas = None
|
| 554 |
+
|
| 555 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
# define per-block multipliers on the learning rate for block-wise LR; may be called from outside this module
|
| 559 |
+
def get_block_lr_weight(
|
| 560 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
|
| 561 |
+
) -> Tuple[List[float], List[float], List[float]]:
|
| 562 |
+
# when no parameter is specified, do nothing and keep the previous behavior
|
| 563 |
+
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
| 564 |
+
return None, None, None
|
| 565 |
+
|
| 566 |
+
max_len = LoRANetwork.NUM_OF_BLOCKS  # number of up/down blocks in the full model
|
| 567 |
+
|
| 568 |
+
def get_list(name_with_suffix) -> List[float]:
|
| 569 |
+
import math
|
| 570 |
+
|
| 571 |
+
tokens = name_with_suffix.split("+")
|
| 572 |
+
name = tokens[0]
|
| 573 |
+
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
| 574 |
+
|
| 575 |
+
if name == "cosine":
|
| 576 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
| 577 |
+
elif name == "sine":
|
| 578 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
| 579 |
+
elif name == "linear":
|
| 580 |
+
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
| 581 |
+
elif name == "reverse_linear":
|
| 582 |
+
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
| 583 |
+
elif name == "zeros":
|
| 584 |
+
return [0.0 + base_lr] * max_len
|
| 585 |
+
else:
|
| 586 |
+
print(
|
| 587 |
+
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
| 588 |
+
% (name, name)
|
| 589 |
+
)
|
| 590 |
+
return None
|
| 591 |
+
|
| 592 |
+
if type(down_lr_weight) == str:
|
| 593 |
+
down_lr_weight = get_list(down_lr_weight)
|
| 594 |
+
if type(up_lr_weight) == str:
|
| 595 |
+
up_lr_weight = get_list(up_lr_weight)
|
| 596 |
+
|
| 597 |
+
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
| 598 |
+
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
| 599 |
+
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
| 600 |
+
up_lr_weight = up_lr_weight[:max_len]
|
| 601 |
+
down_lr_weight = down_lr_weight[:max_len]
|
| 602 |
+
|
| 603 |
+
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
| 604 |
+
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
| 605 |
+
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
| 606 |
+
|
| 607 |
+
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
| 608 |
+
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
| 609 |
+
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
| 610 |
+
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
| 611 |
+
|
| 612 |
+
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
| 613 |
+
print("apply block learning rate / 階層別学習率を適用します。")
|
| 614 |
+
if down_lr_weight != None:
|
| 615 |
+
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
| 616 |
+
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
| 617 |
+
else:
|
| 618 |
+
print("down_lr_weight: all 1.0, すべて1.0")
|
| 619 |
+
|
| 620 |
+
if mid_lr_weight != None:
|
| 621 |
+
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
| 622 |
+
print("mid_lr_weight:", mid_lr_weight)
|
| 623 |
+
else:
|
| 624 |
+
print("mid_lr_weight: 1.0")
|
| 625 |
+
|
| 626 |
+
if up_lr_weight != None:
|
| 627 |
+
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
| 628 |
+
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
| 629 |
+
else:
|
| 630 |
+
print("up_lr_weight: all 1.0, すべて1.0")
|
| 631 |
+
|
| 632 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
# exclude blocks whose lr_weight is 0 from block_dims; may be called from outside this module
|
| 636 |
+
def remove_block_dims_and_alphas(
|
| 637 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
| 638 |
+
):
|
| 639 |
+
# set 0 to block dim without learning rate to remove the block
|
| 640 |
+
if down_lr_weight != None:
|
| 641 |
+
for i, lr in enumerate(down_lr_weight):
|
| 642 |
+
if lr == 0:
|
| 643 |
+
block_dims[i] = 0
|
| 644 |
+
if conv_block_dims is not None:
|
| 645 |
+
conv_block_dims[i] = 0
|
| 646 |
+
if mid_lr_weight != None:
|
| 647 |
+
if mid_lr_weight == 0:
|
| 648 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
| 649 |
+
if conv_block_dims is not None:
|
| 650 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
| 651 |
+
if up_lr_weight != None:
|
| 652 |
+
for i, lr in enumerate(up_lr_weight):
|
| 653 |
+
if lr == 0:
|
| 654 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
| 655 |
+
if conv_block_dims is not None:
|
| 656 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
| 657 |
+
|
| 658 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
# may be called from outside this module
|
| 662 |
+
def get_block_index(lora_name: str) -> int:
|
| 663 |
+
block_idx = -1 # invalid lora name
|
| 664 |
+
|
| 665 |
+
m = RE_UPDOWN.search(lora_name)
|
| 666 |
+
if m:
|
| 667 |
+
g = m.groups()
|
| 668 |
+
i = int(g[1])
|
| 669 |
+
j = int(g[3])
|
| 670 |
+
if g[2] == "resnets":
|
| 671 |
+
idx = 3 * i + j
|
| 672 |
+
elif g[2] == "attentions":
|
| 673 |
+
idx = 3 * i + j
|
| 674 |
+
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
| 675 |
+
idx = 3 * i + 2
|
| 676 |
+
|
| 677 |
+
if g[0] == "down":
|
| 678 |
+
block_idx = 1 + idx  # there is no LoRA corresponding to index 0
|
| 679 |
+
elif g[0] == "up":
|
| 680 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
| 681 |
+
|
| 682 |
+
elif "mid_block_" in lora_name:
|
| 683 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
| 684 |
+
|
| 685 |
+
return block_idx
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
# Create a network from weights for inference; the weights are not loaded here (they may be merged into the model instead)
|
| 689 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
| 690 |
+
if weights_sd is None:
|
| 691 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 692 |
+
from safetensors.torch import load_file, safe_open
|
| 693 |
+
|
| 694 |
+
weights_sd = load_file(file)
|
| 695 |
+
else:
|
| 696 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 697 |
+
|
| 698 |
+
# get dim/alpha mapping
|
| 699 |
+
modules_dim = {}
|
| 700 |
+
modules_alpha = {}
|
| 701 |
+
for key, value in weights_sd.items():
|
| 702 |
+
if "." not in key:
|
| 703 |
+
continue
|
| 704 |
+
|
| 705 |
+
lora_name = key.split(".")[0]
|
| 706 |
+
if "alpha" in key:
|
| 707 |
+
modules_alpha[lora_name] = value
|
| 708 |
+
elif "lora_down" in key:
|
| 709 |
+
dim = value.size()[0]
|
| 710 |
+
modules_dim[lora_name] = dim
|
| 711 |
+
# print(lora_name, value.size(), dim)
|
| 712 |
+
|
| 713 |
+
# support old LoRA without alpha
|
| 714 |
+
for key in modules_dim.keys():
|
| 715 |
+
if key not in modules_alpha:
|
| 716 |
+
modules_alpha[key] = modules_dim[key]
|
| 717 |
+
|
| 718 |
+
module_class = LoRAInfModule if for_inference else LoRAModule
|
| 719 |
+
|
| 720 |
+
network = LoRANetwork(
|
| 721 |
+
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
# block lr
|
| 725 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
| 726 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
| 727 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
| 728 |
+
|
| 729 |
+
return network, weights_sd
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
class LoRANetwork(torch.nn.Module):
|
| 733 |
+
NUM_OF_BLOCKS = 12  # number of up/down blocks in the full model
|
| 734 |
+
|
| 735 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 736 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 737 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
| 738 |
+
LORA_PREFIX_UNET = "lora_unet"
|
| 739 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
| 740 |
+
|
| 741 |
+
# SDXL: must start with LORA_PREFIX_TEXT_ENCODER
|
| 742 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
| 743 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
| 744 |
+
|
| 745 |
+
def __init__(
|
| 746 |
+
self,
|
| 747 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
| 748 |
+
unet,
|
| 749 |
+
multiplier: float = 1.0,
|
| 750 |
+
lora_dim: int = 4,
|
| 751 |
+
alpha: float = 1,
|
| 752 |
+
dropout: Optional[float] = None,
|
| 753 |
+
rank_dropout: Optional[float] = None,
|
| 754 |
+
module_dropout: Optional[float] = None,
|
| 755 |
+
conv_lora_dim: Optional[int] = None,
|
| 756 |
+
conv_alpha: Optional[float] = None,
|
| 757 |
+
block_dims: Optional[List[int]] = None,
|
| 758 |
+
block_alphas: Optional[List[float]] = None,
|
| 759 |
+
conv_block_dims: Optional[List[int]] = None,
|
| 760 |
+
conv_block_alphas: Optional[List[float]] = None,
|
| 761 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
| 762 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
| 763 |
+
module_class: Type[object] = LoRAModule,
|
| 764 |
+
varbose: Optional[bool] = False,
|
| 765 |
+
) -> None:
|
| 766 |
+
"""
|
| 767 |
+
LoRA network: there are many arguments, but they follow these patterns:
|
| 768 |
+
1. specify lora_dim and alpha
|
| 769 |
+
2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
|
| 770 |
+
3. specify block_dims and block_alphas : not applied to Conv2d 3x3
|
| 771 |
+
4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas : also applied to Conv2d 3x3
|
| 772 |
+
5. specify modules_dim and modules_alpha (for inference)
|
| 773 |
+
"""
|
| 774 |
+
super().__init__()
|
| 775 |
+
self.multiplier = multiplier
|
| 776 |
+
|
| 777 |
+
self.lora_dim = lora_dim
|
| 778 |
+
self.alpha = alpha
|
| 779 |
+
self.conv_lora_dim = conv_lora_dim
|
| 780 |
+
self.conv_alpha = conv_alpha
|
| 781 |
+
self.dropout = dropout
|
| 782 |
+
self.rank_dropout = rank_dropout
|
| 783 |
+
self.module_dropout = module_dropout
|
| 784 |
+
|
| 785 |
+
if modules_dim is not None:
|
| 786 |
+
print(f"create LoRA network from weights")
|
| 787 |
+
elif block_dims is not None:
|
| 788 |
+
print(f"create LoRA network from block_dims")
|
| 789 |
+
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
| 790 |
+
print(f"block_dims: {block_dims}")
|
| 791 |
+
print(f"block_alphas: {block_alphas}")
|
| 792 |
+
if conv_block_dims is not None:
|
| 793 |
+
print(f"conv_block_dims: {conv_block_dims}")
|
| 794 |
+
print(f"conv_block_alphas: {conv_block_alphas}")
|
| 795 |
+
else:
|
| 796 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
| 797 |
+
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
| 798 |
+
if self.conv_lora_dim is not None:
|
| 799 |
+
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
| 800 |
+
|
| 801 |
+
# create module instances
|
| 802 |
+
def create_modules(
|
| 803 |
+
is_unet: bool,
|
| 804 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
| 805 |
+
root_module: torch.nn.Module,
|
| 806 |
+
target_replace_modules: List[torch.nn.Module],
|
| 807 |
+
) -> List[LoRAModule]:
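# walk root_module and wrap every Linear / Conv2d child of the target modules in a LoRA module, honoring per-module or per-block dims and alphas; returns the created modules and the names that were skipped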
|
| 808 |
+
prefix = (
|
| 809 |
+
self.LORA_PREFIX_UNET
|
| 810 |
+
if is_unet
|
| 811 |
+
else (
|
| 812 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
| 813 |
+
if text_encoder_idx is None
|
| 814 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
| 815 |
+
)
|
| 816 |
+
)
|
| 817 |
+
loras = []
|
| 818 |
+
skipped = []
|
| 819 |
+
for name, module in root_module.named_modules():
|
| 820 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 821 |
+
for child_name, child_module in module.named_modules():
|
| 822 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
| 823 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
| 824 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
| 825 |
+
|
| 826 |
+
if is_linear or is_conv2d:
|
| 827 |
+
lora_name = prefix + "." + name + "." + child_name
|
| 828 |
+
lora_name = lora_name.replace(".", "_")
|
| 829 |
+
|
| 830 |
+
dim = None
|
| 831 |
+
alpha = None
|
| 832 |
+
|
| 833 |
+
if modules_dim is not None:
|
| 834 |
+
# per-module dims/alphas are specified
|
| 835 |
+
if lora_name in modules_dim:
|
| 836 |
+
dim = modules_dim[lora_name]
|
| 837 |
+
alpha = modules_alpha[lora_name]
|
| 838 |
+
elif is_unet and block_dims is not None:
|
| 839 |
+
# block_dims is specified for the U-Net
|
| 840 |
+
block_idx = get_block_index(lora_name)
|
| 841 |
+
if is_linear or is_conv2d_1x1:
|
| 842 |
+
dim = block_dims[block_idx]
|
| 843 |
+
alpha = block_alphas[block_idx]
|
| 844 |
+
elif conv_block_dims is not None:
|
| 845 |
+
dim = conv_block_dims[block_idx]
|
| 846 |
+
alpha = conv_block_alphas[block_idx]
|
| 847 |
+
else:
|
| 848 |
+
# default case: target all modules
|
| 849 |
+
if is_linear or is_conv2d_1x1:
|
| 850 |
+
dim = self.lora_dim
|
| 851 |
+
alpha = self.alpha
|
| 852 |
+
elif self.conv_lora_dim is not None:
|
| 853 |
+
dim = self.conv_lora_dim
|
| 854 |
+
alpha = self.conv_alpha
|
| 855 |
+
|
| 856 |
+
if dim is None or dim == 0:
|
| 857 |
+
# record which modules were skipped
|
| 858 |
+
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
| 859 |
+
skipped.append(lora_name)
|
| 860 |
+
continue
|
| 861 |
+
|
| 862 |
+
lora = module_class(
|
| 863 |
+
lora_name,
|
| 864 |
+
child_module,
|
| 865 |
+
self.multiplier,
|
| 866 |
+
dim,
|
| 867 |
+
alpha,
|
| 868 |
+
dropout=dropout,
|
| 869 |
+
rank_dropout=rank_dropout,
|
| 870 |
+
module_dropout=module_dropout,
|
| 871 |
+
)
|
| 872 |
+
loras.append(lora)
|
| 873 |
+
return loras, skipped
|
| 874 |
+
|
| 875 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
| 876 |
+
print(text_encoders)
|
| 877 |
+
# create LoRA for text encoder
|
| 878 |
+
# TODO: creating every module each time is wasteful and should be reconsidered
|
| 879 |
+
self.text_encoder_loras = []
|
| 880 |
+
skipped_te = []
|
| 881 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 882 |
+
if len(text_encoders) > 1:
|
| 883 |
+
index = i + 1
|
| 884 |
+
print(f"create LoRA for Text Encoder {index}:")
|
| 885 |
+
else:
|
| 886 |
+
index = None
|
| 887 |
+
print(f"create LoRA for Text Encoder:")
|
| 888 |
+
|
| 889 |
+
print(text_encoder)
|
| 890 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 891 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
| 892 |
+
skipped_te += skipped
|
| 893 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 894 |
+
|
| 895 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
| 896 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
| 897 |
+
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
| 898 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 899 |
+
|
| 900 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
| 901 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 902 |
+
|
| 903 |
+
skipped = skipped_te + skipped_un
|
| 904 |
+
if varbose and len(skipped) > 0:
|
| 905 |
+
print(
|
| 906 |
+
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
| 907 |
+
)
|
| 908 |
+
for name in skipped:
|
| 909 |
+
print(f"\t{name}")
|
| 910 |
+
|
| 911 |
+
self.up_lr_weight: List[float] = None
|
| 912 |
+
self.down_lr_weight: List[float] = None
|
| 913 |
+
self.mid_lr_weight: float = None
|
| 914 |
+
self.block_lr = False
|
| 915 |
+
|
| 916 |
+
# assertion
|
| 917 |
+
names = set()
|
| 918 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 919 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
| 920 |
+
names.add(lora.lora_name)
|
| 921 |
+
|
| 922 |
+
def set_multiplier(self, multiplier):
|
| 923 |
+
self.multiplier = multiplier
|
| 924 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 925 |
+
lora.multiplier = self.multiplier
|
| 926 |
+
|
| 927 |
+
def load_weights(self, file):
|
| 928 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 929 |
+
from safetensors.torch import load_file
|
| 930 |
+
|
| 931 |
+
weights_sd = load_file(file)
|
| 932 |
+
else:
|
| 933 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 934 |
+
info = self.load_state_dict(weights_sd, False)
|
| 935 |
+
return info
|
| 936 |
+
|
| 937 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
| 938 |
+
if apply_text_encoder:
|
| 939 |
+
print("enable LoRA for text encoder")
|
| 940 |
+
else:
|
| 941 |
+
self.text_encoder_loras = []
|
| 942 |
+
|
| 943 |
+
if apply_unet:
|
| 944 |
+
print("enable LoRA for U-Net")
|
| 945 |
+
else:
|
| 946 |
+
self.unet_loras = []
|
| 947 |
+
|
| 948 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 949 |
+
lora.apply_to()
|
| 950 |
+
self.add_module(lora.lora_name, lora)
|
| 951 |
+
|
| 952 |
+
# returns whether the weights can be merged
|
| 953 |
+
def is_mergeable(self):
|
| 954 |
+
return True
|
| 955 |
+
|
| 956 |
+
# TODO refactor to common function with apply_to
|
| 957 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
| 958 |
+
apply_text_encoder = apply_unet = False
|
| 959 |
+
for key in weights_sd.keys():
|
| 960 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
| 961 |
+
apply_text_encoder = True
|
| 962 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
| 963 |
+
apply_unet = True
|
| 964 |
+
|
| 965 |
+
if apply_text_encoder:
|
| 966 |
+
print("enable LoRA for text encoder")
|
| 967 |
+
else:
|
| 968 |
+
self.text_encoder_loras = []
|
| 969 |
+
|
| 970 |
+
if apply_unet:
|
| 971 |
+
print("enable LoRA for U-Net")
|
| 972 |
+
else:
|
| 973 |
+
self.unet_loras = []
|
| 974 |
+
|
| 975 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 976 |
+
sd_for_lora = {}
|
| 977 |
+
for key in weights_sd.keys():
|
| 978 |
+
if key.startswith(lora.lora_name):
|
| 979 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
| 980 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
| 981 |
+
|
| 982 |
+
print(f"weights are merged")
|
| 983 |
+
|
| 984 |
+
# define per-block multipliers on the learning rate for block-wise LR; the argument order is reversed, but never mind for now
|
| 985 |
+
def set_block_lr_weight(
|
| 986 |
+
self,
|
| 987 |
+
up_lr_weight: List[float] = None,
|
| 988 |
+
mid_lr_weight: float = None,
|
| 989 |
+
down_lr_weight: List[float] = None,
|
| 990 |
+
):
|
| 991 |
+
self.block_lr = True
|
| 992 |
+
self.down_lr_weight = down_lr_weight
|
| 993 |
+
self.mid_lr_weight = mid_lr_weight
|
| 994 |
+
self.up_lr_weight = up_lr_weight
|
| 995 |
+
|
| 996 |
+
def get_lr_weight(self, lora: LoRAModule) -> float:
|
| 997 |
+
lr_weight = 1.0
|
| 998 |
+
block_idx = get_block_index(lora.lora_name)
|
| 999 |
+
if block_idx < 0:
|
| 1000 |
+
return lr_weight
|
| 1001 |
+
|
| 1002 |
+
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
| 1003 |
+
if self.down_lr_weight != None:
|
| 1004 |
+
lr_weight = self.down_lr_weight[block_idx]
|
| 1005 |
+
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
| 1006 |
+
if self.mid_lr_weight != None:
|
| 1007 |
+
lr_weight = self.mid_lr_weight
|
| 1008 |
+
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
| 1009 |
+
if self.up_lr_weight != None:
|
| 1010 |
+
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
| 1011 |
+
|
| 1012 |
+
return lr_weight
|
| 1013 |
+
|
| 1014 |
+
# it might be good to allow separate learning rates for the two Text Encoders
|
| 1015 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
| 1016 |
+
self.requires_grad_(True)
|
| 1017 |
+
all_params = []
|
| 1018 |
+
|
| 1019 |
+
def enumerate_params(loras):
|
| 1020 |
+
params = []
|
| 1021 |
+
for lora in loras:
|
| 1022 |
+
params.extend(lora.parameters())
|
| 1023 |
+
return params
|
| 1024 |
+
|
| 1025 |
+
if self.text_encoder_loras:
|
| 1026 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
| 1027 |
+
if text_encoder_lr is not None:
|
| 1028 |
+
param_data["lr"] = text_encoder_lr
|
| 1029 |
+
all_params.append(param_data)
|
| 1030 |
+
|
| 1031 |
+
if self.unet_loras:
|
| 1032 |
+
if self.block_lr:
|
| 1033 |
+
# we want the learning-rate graph per block, so group the LoRA modules by block
|
| 1034 |
+
block_idx_to_lora = {}
|
| 1035 |
+
for lora in self.unet_loras:
|
| 1036 |
+
idx = get_block_index(lora.lora_name)
|
| 1037 |
+
if idx not in block_idx_to_lora:
|
| 1038 |
+
block_idx_to_lora[idx] = []
|
| 1039 |
+
block_idx_to_lora[idx].append(lora)
|
| 1040 |
+
|
| 1041 |
+
# set optimizer parameters for each block
|
| 1042 |
+
for idx, block_loras in block_idx_to_lora.items():
|
| 1043 |
+
param_data = {"params": enumerate_params(block_loras)}
|
| 1044 |
+
|
| 1045 |
+
if unet_lr is not None:
|
| 1046 |
+
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
| 1047 |
+
elif default_lr is not None:
|
| 1048 |
+
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
| 1049 |
+
if ("lr" in param_data) and (param_data["lr"] == 0):
|
| 1050 |
+
continue
|
| 1051 |
+
all_params.append(param_data)
|
| 1052 |
+
|
| 1053 |
+
else:
|
| 1054 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
| 1055 |
+
if unet_lr is not None:
|
| 1056 |
+
param_data["lr"] = unet_lr
|
| 1057 |
+
all_params.append(param_data)
|
| 1058 |
+
|
| 1059 |
+
return all_params
|
| 1060 |
+
|
| 1061 |
+
def enable_gradient_checkpointing(self):
|
| 1062 |
+
# not supported
|
| 1063 |
+
pass
|
| 1064 |
+
|
| 1065 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
| 1066 |
+
self.requires_grad_(True)
|
| 1067 |
+
|
| 1068 |
+
def on_epoch_start(self, text_encoder, unet):
|
| 1069 |
+
self.train()
|
| 1070 |
+
|
| 1071 |
+
def get_trainable_params(self):
|
| 1072 |
+
return self.parameters()
|
| 1073 |
+
|
| 1074 |
+
def save_weights(self, file, dtype, metadata):
|
| 1075 |
+
if metadata is not None and len(metadata) == 0:
|
| 1076 |
+
metadata = None
|
| 1077 |
+
|
| 1078 |
+
state_dict = self.state_dict()
|
| 1079 |
+
|
| 1080 |
+
if dtype is not None:
|
| 1081 |
+
for key in list(state_dict.keys()):
|
| 1082 |
+
v = state_dict[key]
|
| 1083 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 1084 |
+
state_dict[key] = v
|
| 1085 |
+
|
| 1086 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 1087 |
+
from safetensors.torch import save_file
|
| 1088 |
+
from library import train_util
|
| 1089 |
+
|
| 1090 |
+
# Precalculate model hashes to save time on indexing
|
| 1091 |
+
if metadata is None:
|
| 1092 |
+
metadata = {}
|
| 1093 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
| 1094 |
+
metadata["sshs_model_hash"] = model_hash
|
| 1095 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
| 1096 |
+
|
| 1097 |
+
save_file(state_dict, file, metadata)
|
| 1098 |
+
else:
|
| 1099 |
+
torch.save(state_dict, file)
|
| 1100 |
+
|
| 1101 |
+
# mask is a tensor with values from 0 to 1
|
| 1102 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
| 1103 |
+
if mask.max() == 0:
|
| 1104 |
+
mask = torch.ones_like(mask)
|
| 1105 |
+
|
| 1106 |
+
self.mask = mask
|
| 1107 |
+
self.sub_prompt_index = sub_prompt_index
|
| 1108 |
+
self.is_last_network = is_last_network
|
| 1109 |
+
|
| 1110 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1111 |
+
lora.set_network(self)
|
| 1112 |
+
|
| 1113 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
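# store the generation settings and pre-compute the region mask resized to each latent resolution used by the U-Net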
|
| 1114 |
+
self.batch_size = batch_size
|
| 1115 |
+
self.num_sub_prompts = num_sub_prompts
|
| 1116 |
+
self.current_size = (height, width)
|
| 1117 |
+
self.shared = shared
|
| 1118 |
+
|
| 1119 |
+
# create masks
|
| 1120 |
+
mask = self.mask
|
| 1121 |
+
mask_dic = {}
|
| 1122 |
+
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
| 1123 |
+
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
| 1124 |
+
dtype = ref_weight.dtype
|
| 1125 |
+
device = ref_weight.device
|
| 1126 |
+
|
| 1127 |
+
def resize_add(mh, mw):
|
| 1128 |
+
# print(mh, mw, mh * mw)
|
| 1129 |
+
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
| 1130 |
+
m = m.to(device, dtype=dtype)
|
| 1131 |
+
mask_dic[mh * mw] = m
|
| 1132 |
+
|
| 1133 |
+
h = height // 8
|
| 1134 |
+
w = width // 8
|
| 1135 |
+
for _ in range(4):
|
| 1136 |
+
resize_add(h, w)
|
| 1137 |
+
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
| 1138 |
+
resize_add(h + h % 2, w + w % 2)
|
| 1139 |
+
h = (h + 1) // 2
|
| 1140 |
+
w = (w + 1) // 2
|
| 1141 |
+
|
| 1142 |
+
self.mask_dic = mask_dic
|
| 1143 |
+
|
| 1144 |
+
def backup_weights(self):
|
| 1145 |
+
# back up the original weights
|
| 1146 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1147 |
+
for lora in loras:
|
| 1148 |
+
org_module = lora.org_module_ref[0]
|
| 1149 |
+
if not hasattr(org_module, "_lora_org_weight"):
|
| 1150 |
+
sd = org_module.state_dict()
|
| 1151 |
+
org_module._lora_org_weight = sd["weight"].detach().clone()
|
| 1152 |
+
org_module._lora_restored = True
|
| 1153 |
+
|
| 1154 |
+
def restore_weights(self):
|
| 1155 |
+
# restore the original weights
|
| 1156 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1157 |
+
for lora in loras:
|
| 1158 |
+
org_module = lora.org_module_ref[0]
|
| 1159 |
+
if not org_module._lora_restored:
|
| 1160 |
+
sd = org_module.state_dict()
|
| 1161 |
+
sd["weight"] = org_module._lora_org_weight
|
| 1162 |
+
org_module.load_state_dict(sd)
|
| 1163 |
+
org_module._lora_restored = True
|
| 1164 |
+
|
| 1165 |
+
def pre_calculation(self):
|
| 1166 |
+
# pre-calculate (merge the LoRA weights into the original modules)
|
| 1167 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1168 |
+
for lora in loras:
|
| 1169 |
+
org_module = lora.org_module_ref[0]
|
| 1170 |
+
sd = org_module.state_dict()
|
| 1171 |
+
|
| 1172 |
+
org_weight = sd["weight"]
|
| 1173 |
+
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
| 1174 |
+
sd["weight"] = org_weight + lora_weight
|
| 1175 |
+
assert sd["weight"].shape == org_weight.shape
|
| 1176 |
+
org_module.load_state_dict(sd)
|
| 1177 |
+
|
| 1178 |
+
org_module._lora_restored = False
|
| 1179 |
+
lora.enabled = False
|
| 1180 |
+
|
| 1181 |
+
def apply_max_norm_regularization(self, max_norm_value, device):
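# rescale each LoRA up/down pair so the norm of its merged delta (up @ down * alpha / dim) stays within max_norm_value; returns the number of scaled keys, the mean norm and the max norm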
|
| 1182 |
+
downkeys = []
|
| 1183 |
+
upkeys = []
|
| 1184 |
+
alphakeys = []
|
| 1185 |
+
norms = []
|
| 1186 |
+
keys_scaled = 0
|
| 1187 |
+
|
| 1188 |
+
state_dict = self.state_dict()
|
| 1189 |
+
for key in state_dict.keys():
|
| 1190 |
+
if "lora_down" in key and "weight" in key:
|
| 1191 |
+
downkeys.append(key)
|
| 1192 |
+
upkeys.append(key.replace("lora_down", "lora_up"))
|
| 1193 |
+
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
| 1194 |
+
|
| 1195 |
+
for i in range(len(downkeys)):
|
| 1196 |
+
down = state_dict[downkeys[i]].to(device)
|
| 1197 |
+
up = state_dict[upkeys[i]].to(device)
|
| 1198 |
+
alpha = state_dict[alphakeys[i]].to(device)
|
| 1199 |
+
dim = down.shape[0]
|
| 1200 |
+
scale = alpha / dim
|
| 1201 |
+
|
| 1202 |
+
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
| 1203 |
+
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 1204 |
+
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
| 1205 |
+
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
| 1206 |
+
else:
|
| 1207 |
+
updown = up @ down
|
| 1208 |
+
|
| 1209 |
+
updown *= scale
|
| 1210 |
+
|
| 1211 |
+
norm = updown.norm().clamp(min=max_norm_value / 2)
|
| 1212 |
+
desired = torch.clamp(norm, max=max_norm_value)
|
| 1213 |
+
ratio = desired.cpu() / norm.cpu()
|
| 1214 |
+
sqrt_ratio = ratio**0.5
|
| 1215 |
+
if ratio != 1:
|
| 1216 |
+
keys_scaled += 1
|
| 1217 |
+
state_dict[upkeys[i]] *= sqrt_ratio
|
| 1218 |
+
state_dict[downkeys[i]] *= sqrt_ratio
|
| 1219 |
+
scalednorm = updown.norm() * ratio
|
| 1220 |
+
norms.append(scalednorm.item())
|
| 1221 |
+
|
| 1222 |
+
return keys_scaled, sum(norms) / len(norms), max(norms)
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
| 1 |
+
git+https://github.com/huggingface/diffusers.git@aedd78767c99f7bc26a532622d4006280cc6c00d
|
| 2 |
+
transformers
|
| 3 |
+
safetensors
|
| 4 |
+
accelerate
|
sdxl_loras.json
ADDED
|
@@ -0,0 +1,279 @@
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"image": "images/pixel-art-xl.jpeg",
|
| 4 |
+
"title": "Pixel Art XL",
|
| 5 |
+
"repo": "nerijs/pixel-art-xl",
|
| 6 |
+
"trigger_word": "pixel art",
|
| 7 |
+
"weights": "pixel-art-xl.safetensors",
|
| 8 |
+
"is_compatible": true
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"image": "images/riding-min.jpg",
|
| 12 |
+
"title": "Tintin AI",
|
| 13 |
+
"repo": "Pclanglais/TintinIA",
|
| 14 |
+
"trigger_word": "drawing of tintin",
|
| 15 |
+
"weights": "pytorch_lora_weights.safetensors",
|
| 16 |
+
"is_compatible": true,
|
| 17 |
+
"is_nc": true
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"image": "https://huggingface.co/ProomptEngineer/pe-balloon-diffusion-style/resolve/main/2095176.jpeg",
|
| 21 |
+
"title": "PE Balloon Diffusion",
|
| 22 |
+
"repo": "ProomptEngineer/pe-balloon-diffusion-style",
|
| 23 |
+
"trigger_word": "PEBalloonStyle",
|
| 24 |
+
"weights": "PE_BalloonStyle.safetensors",
|
| 25 |
+
"is_compatible": true
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"image": "https://huggingface.co/joachimsallstrom/aether-cloud-lora-for-sdxl/resolve/main/2378710.jpeg",
|
| 29 |
+
"title": "Aether Cloud",
|
| 30 |
+
"repo": "joachimsallstrom/aether-cloud-lora-for-sdxl",
|
| 31 |
+
"trigger_word": "a cloud that looks like a",
|
| 32 |
+
"weights": "Aether_Cloud_v1.safetensors",
|
| 33 |
+
"is_compatible": true
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"image": "images/crayon.png",
|
| 37 |
+
"title": "Crayon Style",
|
| 38 |
+
"repo": "ostris/crayon_style_lora_sdxl",
|
| 39 |
+
"trigger_word": "",
|
| 40 |
+
"weights": "crayons_v1_sdxl.safetensors",
|
| 41 |
+
"is_compatible": true
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"image": "https://tjzk.replicate.delivery/models_models_cover_image/c8b21524-342a-4dd2-bb01-3e65349ed982/image_12.jpeg",
|
| 45 |
+
"title": "Zelda 64 SDXL",
|
| 46 |
+
"repo":"jbilcke-hf/sdxl-zelda64",
|
| 47 |
+
"trigger_word": "in the style of <s0><s1>",
|
| 48 |
+
"weights": "lora.safetensors",
|
| 49 |
+
"text_embedding_weights": "embeddings.pti",
|
| 50 |
+
"is_compatible": false,
|
| 51 |
+
"is_pivotal": true
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"image": "images/papercut_SDXL.jpeg",
|
| 55 |
+
"title": "Papercut SDXL",
|
| 56 |
+
"repo": "TheLastBen/Papercut_SDXL",
|
| 57 |
+
"trigger_word": "papercut",
|
| 58 |
+
"weights": "papercut.safetensors",
|
| 59 |
+
"is_compatible": true
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"image": "https://pbxt.replicate.delivery/8LKCty2D5b5BBBjylErfI8Xqf4OTSsnA0TIJccnpPct3GmeiA/out-0.png",
|
| 63 |
+
"title": "2004 bad digital photography",
|
| 64 |
+
"repo": "fofr/sdxl-2004",
|
| 65 |
+
"trigger_word": "2004, in the style of <s0><s1>",
|
| 66 |
+
"weights": "lora.safetensors",
|
| 67 |
+
"text_embedding_weights": "embeddings.pti",
|
| 68 |
+
"is_compatible": false,
|
| 69 |
+
"is_pivotal": true
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"image": "https://huggingface.co/joachimsallstrom/aether-ghost-lora-for-sdxl/resolve/14de4e59a3f44dabc762855da208cb8f44a7ac78/ghost.png",
|
| 73 |
+
"title": "Aether Ghost",
|
| 74 |
+
"repo": "joachimsallstrom/aether-ghost-lora-for-sdxl",
|
| 75 |
+
"trigger_word": "transparent ghost",
|
| 76 |
+
"weights": "Aether_Ghost_v1.1_LoRA.safetensors",
|
| 77 |
+
"is_compatible": true
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"image": "https://i.imgur.com/Su4bFgm.png",
|
| 81 |
+
"title": "Vulcan SDXL",
|
| 82 |
+
"repo": "davizca87/vulcan",
|
| 83 |
+
"trigger_word": "v5lcn",
|
| 84 |
+
"weights": "v5lcnXL-000004.safetensors",
|
| 85 |
+
"is_compatible": true
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"image":"https://huggingface.co/artificialguybr/ColoringBookRedmond/resolve/main/00009-1364020674.png",
|
| 89 |
+
"title": "ColoringBook.Redmond",
|
| 90 |
+
"repo": "artificialguybr/ColoringBookRedmond",
|
| 91 |
+
"trigger_word": "ColoringBookAF",
|
| 92 |
+
"weights": "ColoringBookRedmond-ColoringBookAF.safetensors",
|
| 93 |
+
"is_compatible": true
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"image": "https://huggingface.co/Norod78/SDXL-LofiGirl-Lora/resolve/main/SDXL-LofiGirl-Lora/Examples/_00044-20230829080050-45-the%20%20girl%20with%20a%20pearl%20earring%20the%20LofiGirl%20%20_lora_SDXL-LofiGirl-Lora_1_%2C%20Very%20detailed%2C%20clean%2C%20high%20quality%2C%20sharp%20image.jpg",
|
| 97 |
+
"title": "LoFi Girl SDXL",
|
| 98 |
+
"repo": "Norod78/SDXL-LofiGirl-Lora",
|
| 99 |
+
"trigger_word": "LofiGirl",
|
| 100 |
+
"weights": "SDXL-LofiGirl-Lora.safetensors",
|
| 101 |
+
"is_compatible": true
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"image": "images/embroid.png",
|
| 105 |
+
"title": "Embroidery Style",
|
| 106 |
+
"repo": "ostris/embroidery_style_lora_sdxl",
|
| 107 |
+
"trigger_word": "",
|
| 108 |
+
"weights": "embroidered_style_v1_sdxl.safetensors",
|
| 109 |
+
"is_compatible": true
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
"image": "images/3d_style_4.jpeg",
|
| 113 |
+
"title": "3D Render Style",
|
| 114 |
+
"repo": "goofyai/3d_render_style_xl",
|
| 115 |
+
"trigger_word": "3d style",
|
| 116 |
+
"weights": "3d_render_style_xl.safetensors",
|
| 117 |
+
"is_compatible": true
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"image": "images/watercolor.png",
|
| 121 |
+
"title": "Watercolor Style",
|
| 122 |
+
"repo": "ostris/watercolor_style_lora_sdxl",
|
| 123 |
+
"trigger_word": "",
|
| 124 |
+
"weights": "watercolor_v1_sdxl.safetensors",
|
| 125 |
+
"is_compatible": true
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"image": "https://huggingface.co/veryVANYA/ps1-graphics-sdxl/resolve/main/2070471.jpeg",
|
| 129 |
+
"title": "PS1 Graphics v2 SDXL",
|
| 130 |
+
"repo":"veryVANYA/ps1-graphics-sdxl-v2",
|
| 131 |
+
"trigger_word": "ps1 style",
|
| 132 |
+
"weights": "ps1_style_SDXL_v2.safetensors",
|
| 133 |
+
"is_compatible": true
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"image": "images/william_eggleston.webp",
|
| 137 |
+
"title": "William Eggleston Style",
|
| 138 |
+
"repo": "TheLastBen/William_Eggleston_Style_SDXL",
|
| 139 |
+
"trigger_word": "by william eggleston",
|
| 140 |
+
"weights": "wegg.safetensors",
|
| 141 |
+
"is_compatible": true
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"image": "https://huggingface.co/davizca87/c-a-g-coinmaker/resolve/main/1722160.jpeg",
|
| 145 |
+
"title": "CAG Coinmaker",
|
| 146 |
+
"repo": "davizca87/c-a-g-coinmaker",
|
| 147 |
+
"trigger_word": "c01n",
|
| 148 |
+
"weights": "c01n-000010.safetensors",
|
| 149 |
+
"is_compatible": true
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"image": "images/dog.png",
|
| 153 |
+
"title": "Cyborg Style",
|
| 154 |
+
"repo": "goofyai/cyborg_style_xl",
|
| 155 |
+
"trigger_word": "cyborg style",
|
| 156 |
+
"weights": "cyborg_style_xl-off.safetensors",
|
| 157 |
+
"is_compatible": true
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"image": "images/ToyRedmond-ToyLoraForSDXL10.png",
|
| 161 |
+
"title": "Toy.Redmond",
|
| 162 |
+
"repo": "artificialguybr/ToyRedmond-ToyLoraForSDXL10",
|
| 163 |
+
"trigger_word": "FnkRedmAF",
|
| 164 |
+
"weights": "ToyRedmond-FnkRedmAF.safetensors",
|
| 165 |
+
"is_compatible": true
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"image": "images/voxel-xl-lora.png",
|
| 169 |
+
"title": "Voxel XL",
|
| 170 |
+
"repo": "Fictiverse/Voxel_XL_Lora",
|
| 171 |
+
"trigger_word": "voxel style",
|
| 172 |
+
"weights": "VoxelXL_v1.safetensors",
|
| 173 |
+
"is_compatible": true
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"image": "images/uglysonic.webp",
|
| 177 |
+
"title": "Ugly Sonic",
|
| 178 |
+
"repo": "minimaxir/sdxl-ugly-sonic-lora",
|
| 179 |
+
"trigger_word": "sonic the hedgehog",
|
| 180 |
+
"weights": "pytorch_lora_weights.bin",
|
| 181 |
+
"is_compatible": true
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"image": "images/corgi_brick.jpeg",
|
| 185 |
+
"title": "Lego BrickHeadz",
|
| 186 |
+
"repo": "nerijs/lego-brickheadz-xl",
|
| 187 |
+
"trigger_word": "lego brickheadz",
|
| 188 |
+
"weights": "legobrickheadz-v1.0-000004.safetensors",
|
| 189 |
+
"is_compatible": true
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"image": "images/lego-minifig-xl.jpeg",
|
| 193 |
+
"title": "Lego Minifig XL",
|
| 194 |
+
"repo": "nerijs/lego-minifig-xl",
|
| 195 |
+
"trigger_word": "lego minifig",
|
        "weights": "legominifig-v1.0-000003.safetensors",
        "is_compatible": true
    },
    {
        "image": "images/jojoso1.jpg",
        "title": "JoJo's Bizarre style",
        "repo": "Norod78/SDXL-jojoso_style-Lora",
        "trigger_word": "jojoso style",
        "weights": "SDXL-jojoso_style-Lora-r8.safetensors",
        "is_compatible": true
    },
    {
        "image": "images/pikachu.webp",
        "title": "Pikachu XL",
        "repo": "TheLastBen/Pikachu_SDXL",
        "trigger_word": "pikachu",
        "weights": "pikachu.safetensors",
        "is_compatible": true
    },
    {
        "image": "images/LogoRedmond-LogoLoraForSDXL.jpeg",
        "title": "Logo.Redmond",
        "repo": "artificialguybr/LogoRedmond-LogoLoraForSDXL",
        "trigger_word": "LogoRedAF",
        "weights": "LogoRedmond_LogoRedAF.safetensors",
        "is_compatible": true
    },
    {
        "image": "https://huggingface.co/Norod78/SDXL-StickerSheet-Lora/resolve/main/Examples/00073-20230831113700-7780-Cthulhu%20StickerSheet%20%20_lora_SDXL-StickerSheet-Lora_1_%2C%20based%20on%20H.P%20Lovecraft%20stories%2C%20Very%20detailed%2C%20clean%2C%20high%20quality%2C%20sharp.jpg",
        "title": "Sticker Sheet",
        "repo": "Norod78/SDXL-StickerSheet-Lora",
        "trigger_word": "StickerSheet",
        "weights": "SDXL-StickerSheet-Lora.safetensors",
        "is_compatible": true
    },
    {
        "image": "images/LineAni.Redmond.png",
        "title": "LinearManga.Redmond",
        "repo": "artificialguybr/LineAniRedmond-LinearMangaSDXL",
        "trigger_word": "LineAniAF",
        "weights": "LineAniRedmond-LineAniAF.safetensors",
        "is_compatible": true
    },
    {
        "image": "images/josef_koudelka.webp",
        "title": "Josef Koudelka Style",
        "repo": "TheLastBen/Josef_Koudelka_Style_SDXL",
        "trigger_word": "by josef koudelka",
        "weights": "koud.safetensors",
        "is_compatible": true
    },
    {
        "image": "https://huggingface.co/goofyai/Leonardo_Ai_Style_Illustration/resolve/main/leo-2.png",
        "title": "Leonardo Style",
        "repo": "goofyai/Leonardo_Ai_Style_Illustration",
        "trigger_word": "leonardo style",
        "weights": "leonardo_illustration.safetensors",
        "is_compatible": true
    },
    {
        "image": "https://huggingface.co/Norod78/SDXL-simpstyle-Lora/resolve/main/Examples/00006-20230820150225-558-the%20girl%20with%20a%20pearl%20earring%20by%20johannes%20vermeer%20simpstyle%20_lora_SDXL-simpstyle-Lora_1_%2C%20Very%20detailed%2C%20clean%2C%20high%20quality%2C%20sh.jpg",
        "title": "SimpStyle",
        "repo": "Norod78/SDXL-simpstyle-Lora",
        "trigger_word": "simpstyle",
        "weights": "SDXL-simpstyle-Lora-r8.safetensors",
        "is_compatible": true
    },
    {
        "image": "https://huggingface.co/artificialguybr/StoryBookRedmond/resolve/main/00162-1569823442.png",
        "title": "Storybook.Redmond",
        "repo": "artificialguybr/StoryBookRedmond",
        "trigger_word": "KidsRedmAF",
        "weights": "StoryBookRedmond-KidsRedmAF.safetensors",
        "is_compatible": true
    },
    {
        "image": "https://huggingface.co/chillpixel/blacklight-makeup-sdxl-lora/resolve/main/preview.png",
        "title": "Blacklight Makeup",
        "repo": "chillpixel/blacklight-makeup-sdxl-lora",
        "trigger_word": "with blacklight makeup",
        "weights": "pytorch_lora_weights.bin",
        "is_compatible": true
    }
]
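For context, a minimal sketch (an assumption, not part of this commit) of how a script such as app.py could load sdxl_loras.json and expose its entries; the key names are the ones shown in the file above:

import json

# Load the LoRA metadata shipped with the Space (path assumed relative to the repo root).
with open("sdxl_loras.json", "r") as file:
    sdxl_loras = json.load(file)

# Each entry provides a preview image, a display title, the Hugging Face repo id,
# the trigger word to include in prompts, the weight filename, and a compatibility flag.
gallery_entries = [(item["image"], item["title"]) for item in sdxl_loras]
incompatible_repos = [item["repo"] for item in sdxl_loras if not item["is_compatible"]]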
share_btn.py
ADDED
@@ -0,0 +1,76 @@
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
    <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
    <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
</svg>"""

loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
    style="color: #ffffff;
"
    xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""

share_js = """async () => {
    // Upload a file to the Hugging Face uploads endpoint and return its public URL.
    async function uploadFile(file){
        const UPLOAD_URL = 'https://huggingface.co/uploads';
        const response = await fetch(UPLOAD_URL, {
            method: 'POST',
            headers: {
                'Content-Type': file.type,
                'X-Requested-With': 'XMLHttpRequest',
            },
            body: file, /// <- File inherits from Blob
        });
        const url = await response.text();
        return url;
    }

    // Fetch the rendered output image and wrap it in a File object for upload.
    async function getInputImgFile(imgEl){
        const res = await fetch(imgEl.src);
        const blob = await res.blob();
        const imgId = Date.now() % 200;
        const isPng = imgEl.src.startsWith(`data:image/png`);
        if(isPng){
            const fileName = `sd-perception-${imgId}.png`;
            return new File([blob], fileName, { type: 'image/png' });
        }else{
            const fileName = `sd-perception-${imgId}.jpg`;
            return new File([blob], fileName, { type: 'image/jpeg' });
        }
    }

    const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
    const selectedLoRA = gradioEl.querySelector('#selected_lora').innerHTML;
    const inputPrompt = gradioEl.querySelector('#prompt input').value;
    const outputImgEl = gradioEl.querySelector('#result-image img');

    const shareBtnEl = gradioEl.querySelector('#share-btn');
    const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
    const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');

    shareBtnEl.style.pointerEvents = 'none';
    shareIconEl.style.display = 'none';
    loadingIconEl.style.removeProperty('display');

    const inputFile = await getInputImgFile(outputImgEl);
    const urlInputImg = await uploadFile(inputFile);

    const descriptionMd = `

${selectedLoRA}

### Prompt
${inputPrompt}

#### Generated Image:
<img src="${urlInputImg}" />
`;
    const params = new URLSearchParams({
        title: inputPrompt,
        description: descriptionMd,
        preview: true
    });
    const paramsStr = params.toString();
    window.open(`https://huggingface.co/spaces/multimodalart/LoraTheExplorer/discussions/new?${paramsStr}`, '_blank');
    shareBtnEl.style.removeProperty('pointer-events');
    shareIconEl.style.removeProperty('display');
    loadingIconEl.style.display = 'none';
}"""
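For context, a minimal sketch (an assumption, not part of this commit) of how app.py could wire these constants into a Gradio 3.x interface; the elem_id values match the selectors that share_js queries for, but the exact layout in app.py may differ:

import gradio as gr
from share_btn import community_icon_html, loading_icon_html, share_js

with gr.Blocks() as demo:
    # Components whose elem_id values are read by share_js via document queries.
    prompt = gr.Textbox(elem_id="prompt", label="Prompt")
    result = gr.Image(elem_id="result-image", label="Result")
    selected_lora = gr.HTML(elem_id="selected_lora")

    # Share button with its icons; the JS swaps the icons while the upload runs.
    community_icon = gr.HTML(community_icon_html)
    loading_icon = gr.HTML(loading_icon_html)
    share_button = gr.Button("Share to community", elem_id="share-btn")

    # In Gradio 3.x, passing _js with fn=None runs the snippet entirely client-side.
    share_button.click(None, [], [], _js=share_js)

demo.launch()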