menard multimodalart HF Staff committed on
Commit 45732de · 0 Parent(s):

Duplicate from multimodalart/LoraTheExplorer

Co-authored-by: Apolinário from multimodal AI art <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .vscode
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: LoRA the Explorer
+ emoji: 🔎 🖼️
+ colorFrom: indigo
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 3.39.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: multimodalart/LoraTheExplorer
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,289 @@
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
4
+ from huggingface_hub import hf_hub_download
5
+ from share_btn import community_icon_html, loading_icon_html, share_js
6
+ from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
7
+ import lora
8
+ from time import sleep
9
+ import copy
10
+ import json
11
+ import gc
12
+
13
+ with open("sdxl_loras.json", "r") as file:
14
+ data = json.load(file)
15
+ sdxl_loras = [
16
+ {
17
+ "image": item["image"],
18
+ "title": item["title"],
19
+ "repo": item["repo"],
20
+ "trigger_word": item["trigger_word"],
21
+ "weights": item["weights"],
22
+ "is_compatible": item["is_compatible"],
23
+ "is_pivotal": item.get("is_pivotal", False),
24
+ "text_embedding_weights": item.get("text_embedding_weights", None),
25
+ "is_nc": item.get("is_nc", False)
26
+ }
27
+ for item in data
28
+ ]
29
+ print(sdxl_loras)
30
+ saved_names = [
31
+ hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
32
+ ]
33
+
34
+ device = "cuda" # replace this to `mps` if on a MacOS Silicon
35
+
36
+ vae = AutoencoderKL.from_pretrained(
37
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
38
+ )
39
+ pipe = StableDiffusionXLPipeline.from_pretrained(
40
+ "stabilityai/stable-diffusion-xl-base-1.0",
41
+ vae=vae,
42
+ torch_dtype=torch.float16,
43
+ ).to("cpu")
44
+ original_pipe = copy.deepcopy(pipe)
45
+ pipe.to(device)
46
+
47
+ last_lora = ""
48
+ last_merged = False
49
+
50
+
51
+ def update_selection(selected_state: gr.SelectData):
52
+ lora_repo = sdxl_loras[selected_state.index]["repo"]
53
+ instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
54
+ new_placeholder = "Type a prompt. This LoRA applies for all prompts, no need for a trigger word" if instance_prompt == "" else "Type a prompt to use your selected LoRA"
55
+ weight_name = sdxl_loras[selected_state.index]["weights"]
56
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"
57
+ is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
58
+ is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
59
+
60
+ use_with_diffusers = f'''
61
+ ## Using [`{lora_repo}`](https://huggingface.co/{lora_repo})
62
+
63
+ ## Use it with diffusers:
64
+ '''
65
+ if is_compatible:
66
+ use_with_diffusers += f'''
67
+ from diffusers import StableDiffusionXLPipeline
68
+ import torch
69
+
70
+ model_path = "stabilityai/stable-diffusion-xl-base-1.0"
71
+ pipe = StableDiffusionXLPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
72
+ pipe.to("cuda")
73
+ pipe.load_lora_weights("{lora_repo}", weight_name="{weight_name}")
74
+
75
+ prompt = "{instance_prompt}..."
76
+ lora_scale= 0.9
77
+ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={{"scale": lora_scale}}).images[0]
78
+ image.save("image.png")
79
+ '''
80
+ elif not is_pivotal:
81
+ use_with_diffusers += "This LoRA is not compatible with diffusers natively yet. But you can still use it on diffusers with `bmaltais/kohya_ss` LoRA class, check out this [Google Colab](https://colab.research.google.com/drive/14aEJsKdEQ9_kyfsiV6JDok799kxPul0j )"
82
+ else:
83
+ use_with_diffusers += f"This LoRA is not compatible with diffusers natively yet. But you can still use it on diffusers with sdxl-cog `TokenEmbeddingsHandler` class, check out the [model repo](https://huggingface.co/{lora_repo}#inference-with-🧨-diffusers)"
84
+ use_with_uis = f'''
85
+ ## Use it with Comfy UI, Invoke AI, SD.Next, AUTO1111:
86
+
87
+ ### Download the `*.safetensors` weights from [here](https://huggingface.co/{lora_repo}/resolve/main/{weight_name})
88
+
89
+ - [ComfyUI guide](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
90
+ - [Invoke AI guide](https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/?h=lora#using-loras)
91
+ - [SD.Next guide](https://github.com/vladmandic/automatic)
92
+ - [AUTOMATIC1111 guide](https://stable-diffusion-art.com/lora/)
93
+ '''
94
+ return (
95
+ updated_text,
96
+ instance_prompt,
97
+ gr.update(placeholder=new_placeholder),
98
+ selected_state,
99
+ use_with_diffusers,
100
+ use_with_uis,
101
+ )
102
+
103
+
104
+ def check_selected(selected_state):
105
+ if not selected_state:
106
+ raise gr.Error("You must select a LoRA")
107
+
108
+ def merge_incompatible_lora(full_path_lora, lora_scale):
109
+ for weights_file in [full_path_lora]:
110
+ if ";" in weights_file:
111
+ weights_file, multiplier = weights_file.split(";")
112
+ multiplier = float(multiplier)
113
+ else:
114
+ multiplier = lora_scale
115
+
116
+ lora_model, weights_sd = lora.create_network_from_weights(
117
+ multiplier,
118
+ full_path_lora,
119
+ pipe.vae,
120
+ pipe.text_encoder,
121
+ pipe.unet,
122
+ for_inference=True,
123
+ )
124
+ lora_model.merge_to(
125
+ pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
126
+ )
127
+ del weights_sd
128
+ del lora_model
129
+ gc.collect()
130
+
131
+ def run_lora(prompt, negative, lora_scale, selected_state):
132
+ global last_lora, last_merged, pipe
133
+
134
+ if negative == "":
135
+ negative = None
136
+
137
+ if not selected_state:
138
+ raise gr.Error("You must select a LoRA")
139
+ repo_name = sdxl_loras[selected_state.index]["repo"]
140
+ weight_name = sdxl_loras[selected_state.index]["weights"]
141
+ full_path_lora = saved_names[selected_state.index]
142
+ cross_attention_kwargs = None
143
+ if last_lora != repo_name:
144
+ if last_merged:
145
+ del pipe
146
+ gc.collect()
147
+ pipe = copy.deepcopy(original_pipe)
148
+ pipe.to(device)
149
+ else:
150
+ pipe.unload_lora_weights()
151
+ is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
152
+ if is_compatible:
153
+ pipe.load_lora_weights(full_path_lora)
154
+ cross_attention_kwargs = {"scale": lora_scale}
155
+ else:
156
+ is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
157
+ if is_pivotal:
158
+
159
+ pipe.load_lora_weights(full_path_lora)
160
+ cross_attention_kwargs = {"scale": lora_scale}
161
+
162
+ #Add the textual inversion embeddings from pivotal tuning models
163
+ text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
164
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
165
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
166
+ embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
167
+ embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
168
+ embhandler.load_embeddings(embedding_path)
169
+ else:
170
+ merge_incompatible_lora(full_path_lora, lora_scale)
171
+ last_merged = True
172
+
173
+ image = pipe(
174
+ prompt=prompt,
175
+ negative_prompt=negative,
176
+ width=768,
177
+ height=768,
178
+ num_inference_steps=20,
179
+ guidance_scale=7.5,
180
+ cross_attention_kwargs=cross_attention_kwargs,
181
+ ).images[0]
182
+ last_lora = repo_name
183
+ gc.collect()
184
+ return image, gr.update(visible=True)
185
+
186
+
187
+ with gr.Blocks(css="custom.css") as demo:
188
+ title = gr.HTML(
189
+ """<h1><img src="https://i.imgur.com/vT48NAO.png" alt="LoRA"> LoRA the Explorer</h1>""",
190
+ elem_id="title",
191
+ )
192
+ selected_state = gr.State()
193
+ with gr.Row():
194
+ gallery = gr.Gallery(
195
+ value=[(item["image"], item["title"]) for item in sdxl_loras],
196
+ label="SDXL LoRA Gallery",
197
+ allow_preview=False,
198
+ columns=3,
199
+ elem_id="gallery",
200
+ show_share_button=False
201
+ )
202
+ with gr.Column():
203
+ prompt_title = gr.Markdown(
204
+ value="### Click on a LoRA in the gallery to select it",
205
+ visible=True,
206
+ elem_id="selected_lora",
207
+ )
208
+ with gr.Row():
209
+ prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, placeholder="Type a prompt after selecting a LoRA", elem_id="prompt")
210
+ button = gr.Button("Run", elem_id="run_button")
211
+ with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
212
+ community_icon = gr.HTML(community_icon_html)
213
+ loading_icon = gr.HTML(loading_icon_html)
214
+ share_button = gr.Button("Share to community", elem_id="share-btn")
215
+ result = gr.Image(
216
+ interactive=False, label="Generated Image", elem_id="result-image"
217
+ )
218
+ with gr.Accordion("Advanced options", open=False):
219
+ negative = gr.Textbox(label="Negative Prompt")
220
+ weight = gr.Slider(0, 10, value=1, step=0.1, label="LoRA weight")
221
+
222
+ with gr.Column(elem_id="extra_info"):
223
+ with gr.Accordion(
224
+ "Use it with: 🧨 diffusers, ComfyUI, Invoke AI, SD.Next, AUTO1111",
225
+ open=False,
226
+ elem_id="accordion",
227
+ ):
228
+ with gr.Row():
229
+ use_diffusers = gr.Markdown("""## Select a LoRA first 🤗""")
230
+ use_uis = gr.Markdown()
231
+ with gr.Accordion("Submit a LoRA! 📥", open=False):
232
+ submit_title = gr.Markdown(
233
+ "### Streamlined submission coming soon! Until then [suggest your LoRA in the community tab](https://huggingface.co/spaces/multimodalart/LoraTheExplorer/discussions) 🤗"
234
+ )
235
+ with gr.Box(elem_id="soon"):
236
+ submit_source = gr.Radio(
237
+ ["Hugging Face", "CivitAI"],
238
+ label="LoRA source",
239
+ value="Hugging Face",
240
+ )
241
+ with gr.Row():
242
+ submit_source_hf = gr.Textbox(
243
+ label="Hugging Face Model Repo",
244
+ info="In the format `username/model_id`",
245
+ )
246
+ submit_safetensors_hf = gr.Textbox(
247
+ label="Safetensors filename",
248
+ info="The filename `*.safetensors` in the model repo",
249
+ )
250
+ with gr.Row():
251
+ submit_trigger_word_hf = gr.Textbox(label="Trigger word")
252
+ submit_image = gr.Image(
253
+ label="Example image (optional if the repo already contains images)"
254
+ )
255
+ submit_button = gr.Button("Submit!")
256
+ submit_disclaimer = gr.Markdown(
257
+ "This is a curated gallery by me, [apolinário (multimodal.art)](https://twitter.com/multimodalart). I'll try to include as many cool LoRAs as they are submitted! You can [duplicate this Space](https://huggingface.co/spaces/multimodalart/LoraTheExplorer?duplicate=true) to use it privately, and add your own LoRAs by editing `sdxl_loras.json` in the Files tab of your private space."
258
+ )
259
+
260
+ gallery.select(
261
+ update_selection,
262
+ outputs=[prompt_title, prompt, prompt, selected_state, use_diffusers, use_uis],
263
+ queue=False,
264
+ show_progress=False,
265
+ )
266
+ prompt.submit(
267
+ fn=check_selected,
268
+ inputs=[selected_state],
269
+ queue=False,
270
+ show_progress=False
271
+ ).success(
272
+ fn=run_lora,
273
+ inputs=[prompt, negative, weight, selected_state],
274
+ outputs=[result, share_group],
275
+ )
276
+ button.click(
277
+ fn=check_selected,
278
+ inputs=[selected_state],
279
+ queue=False,
280
+ show_progress=False
281
+ ).success(
282
+ fn=run_lora,
283
+ inputs=[prompt, negative, weight, selected_state],
284
+ outputs=[result, share_group],
285
+ )
286
+ share_button.click(None, [], [], _js=share_js)
287
+
288
+ demo.queue(max_size=20)
289
+ demo.launch()
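
The gallery above is driven by `sdxl_loras.json`, which `app.py` reads at startup: every entry has to provide the keys consumed in the list comprehension at the top of the file (`image`, `title`, `repo`, `trigger_word`, `weights`, `is_compatible`), plus the optional `is_pivotal`, `text_embedding_weights`, and `is_nc` flags. A minimal sketch of one entry, using placeholder repo, file, and image names rather than a real LoRA:

```python
# Hypothetical sdxl_loras.json entry, written here as a Python dict for clarity;
# the field names match the keys read in app.py, the values are placeholders.
example_entry = {
    "image": "images/example.png",          # gallery thumbnail committed to the Space
    "title": "Example LoRA",
    "repo": "username/example-sdxl-lora",   # Hugging Face model repo
    "trigger_word": "example style",        # empty string if no trigger word is needed
    "weights": "lora.safetensors",          # safetensors file inside the repo
    "is_compatible": True,                  # loadable directly via pipe.load_lora_weights
    "is_pivotal": False,                    # True for pivotal-tuning LoRAs that ship embeddings
    "text_embedding_weights": None,         # e.g. "embeddings.safetensors" when is_pivotal is True
    "is_nc": False,                         # marks non-commercial (cc-by-nc) LoRAs
}
```

Appending an object like this to the JSON array (and committing the referenced thumbnail) is, per the disclaimer in the submission accordion, how a duplicated Space adds its own LoRAs.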
cog_sdxl_dataset_and_utils.py ADDED
@@ -0,0 +1,422 @@
1
+ # dataset_and_utils.py file taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
2
+ import os
3
+ from typing import Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import PIL
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
11
+ from PIL import Image
12
+ from safetensors import safe_open
13
+ from safetensors.torch import save_file
14
+ from torch.utils.data import Dataset
15
+ from transformers import AutoTokenizer, PretrainedConfig
16
+
17
+
18
+ def prepare_image(
19
+ pil_image: PIL.Image.Image, w: int = 512, h: int = 512
20
+ ) -> torch.Tensor:
21
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
22
+ arr = np.array(pil_image.convert("RGB"))
23
+ arr = arr.astype(np.float32) / 127.5 - 1
24
+ arr = np.transpose(arr, [2, 0, 1])
25
+ image = torch.from_numpy(arr).unsqueeze(0)
26
+ return image
27
+
28
+
29
+ def prepare_mask(
30
+ pil_image: PIL.Image.Image, w: int = 512, h: int = 512
31
+ ) -> torch.Tensor:
32
+ pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
33
+ arr = np.array(pil_image.convert("L"))
34
+ arr = arr.astype(np.float32) / 255.0
35
+ arr = np.expand_dims(arr, 0)
36
+ image = torch.from_numpy(arr).unsqueeze(0)
37
+ return image
38
+
39
+
40
+ class PreprocessedDataset(Dataset):
41
+ def __init__(
42
+ self,
43
+ csv_path: str,
44
+ tokenizer_1,
45
+ tokenizer_2,
46
+ vae_encoder,
47
+ text_encoder_1=None,
48
+ text_encoder_2=None,
49
+ do_cache: bool = False,
50
+ size: int = 512,
51
+ text_dropout: float = 0.0,
52
+ scale_vae_latents: bool = True,
53
+ substitute_caption_map: Dict[str, str] = {},
54
+ ):
55
+ super().__init__()
56
+
57
+ self.data = pd.read_csv(csv_path)
58
+ self.csv_path = csv_path
59
+
60
+ self.caption = self.data["caption"]
61
+ # make it lowercase
62
+ self.caption = self.caption.str.lower()
63
+ for key, value in substitute_caption_map.items():
64
+ self.caption = self.caption.str.replace(key.lower(), value)
65
+
66
+ self.image_path = self.data["image_path"]
67
+
68
+ if "mask_path" not in self.data.columns:
69
+ self.mask_path = None
70
+ else:
71
+ self.mask_path = self.data["mask_path"]
72
+
73
+ if text_encoder_1 is None:
74
+ self.return_text_embeddings = False
75
+ else:
76
+ self.text_encoder_1 = text_encoder_1
77
+ self.text_encoder_2 = text_encoder_2
78
+ self.return_text_embeddings = True
79
+ raise NotImplementedError(
80
+ "Preprocessing Text Encoder is not implemented yet"
81
+ )
82
+
83
+ self.tokenizer_1 = tokenizer_1
84
+ self.tokenizer_2 = tokenizer_2
85
+
86
+ self.vae_encoder = vae_encoder
87
+ self.scale_vae_latents = scale_vae_latents
88
+ self.text_dropout = text_dropout
89
+
90
+ self.size = size
91
+
92
+ if do_cache:
93
+ self.vae_latents = []
94
+ self.tokens_tuple = []
95
+ self.masks = []
96
+
97
+ self.do_cache = True
98
+
99
+ print("Captions to train on: ")
100
+ for idx in range(len(self.data)):
101
+ token, vae_latent, mask = self._process(idx)
102
+ self.vae_latents.append(vae_latent)
103
+ self.tokens_tuple.append(token)
104
+ self.masks.append(mask)
105
+
106
+ del self.vae_encoder
107
+
108
+ else:
109
+ self.do_cache = False
110
+
111
+ @torch.no_grad()
112
+ def _process(
113
+ self, idx: int
114
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
115
+ image_path = self.image_path[idx]
116
+ image_path = os.path.join(os.path.dirname(self.csv_path), image_path)
117
+
118
+ image = PIL.Image.open(image_path).convert("RGB")
119
+ image = prepare_image(image, self.size, self.size).to(
120
+ dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
121
+ )
122
+
123
+ caption = self.caption[idx]
124
+
125
+ print(caption)
126
+
127
+ # tokenizer_1
128
+ ti1 = self.tokenizer_1(
129
+ caption,
130
+ padding="max_length",
131
+ max_length=77,
132
+ truncation=True,
133
+ add_special_tokens=True,
134
+ return_tensors="pt",
135
+ ).input_ids
136
+
137
+ ti2 = self.tokenizer_2(
138
+ caption,
139
+ padding="max_length",
140
+ max_length=77,
141
+ truncation=True,
142
+ add_special_tokens=True,
143
+ return_tensors="pt",
144
+ ).input_ids
145
+
146
+ vae_latent = self.vae_encoder.encode(image).latent_dist.sample()
147
+
148
+ if self.scale_vae_latents:
149
+ vae_latent = vae_latent * self.vae_encoder.config.scaling_factor
150
+
151
+ if self.mask_path is None:
152
+ mask = torch.ones_like(
153
+ vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
154
+ )
155
+
156
+ else:
157
+ mask_path = self.mask_path[idx]
158
+ mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path)
159
+
160
+ mask = PIL.Image.open(mask_path)
161
+ mask = prepare_mask(mask, self.size, self.size).to(
162
+ dtype=self.vae_encoder.dtype, device=self.vae_encoder.device
163
+ )
164
+
165
+ mask = torch.nn.functional.interpolate(
166
+ mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest"
167
+ )
168
+ mask = mask.repeat(1, vae_latent.shape[1], 1, 1)
169
+
170
+ assert len(mask.shape) == 4 and len(vae_latent.shape) == 4
171
+
172
+ return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze()
173
+
174
+ def __len__(self) -> int:
175
+ return len(self.data)
176
+
177
+ def atidx(
178
+ self, idx: int
179
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
180
+ if self.do_cache:
181
+ return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx]
182
+ else:
183
+ return self._process(idx)
184
+
185
+ def __getitem__(
186
+ self, idx: int
187
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
188
+ token, vae_latent, mask = self.atidx(idx)
189
+ return token, vae_latent, mask
190
+
191
+
192
+ def import_model_class_from_model_name_or_path(
193
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
194
+ ):
195
+ text_encoder_config = PretrainedConfig.from_pretrained(
196
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
197
+ )
198
+ model_class = text_encoder_config.architectures[0]
199
+
200
+ if model_class == "CLIPTextModel":
201
+ from transformers import CLIPTextModel
202
+
203
+ return CLIPTextModel
204
+ elif model_class == "CLIPTextModelWithProjection":
205
+ from transformers import CLIPTextModelWithProjection
206
+
207
+ return CLIPTextModelWithProjection
208
+ else:
209
+ raise ValueError(f"{model_class} is not supported.")
210
+
211
+
212
+ def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
213
+ tokenizer_one = AutoTokenizer.from_pretrained(
214
+ pretrained_model_name_or_path,
215
+ subfolder="tokenizer",
216
+ revision=revision,
217
+ use_fast=False,
218
+ )
219
+ tokenizer_two = AutoTokenizer.from_pretrained(
220
+ pretrained_model_name_or_path,
221
+ subfolder="tokenizer_2",
222
+ revision=revision,
223
+ use_fast=False,
224
+ )
225
+
226
+ # Load scheduler and models
227
+ noise_scheduler = DDPMScheduler.from_pretrained(
228
+ pretrained_model_name_or_path, subfolder="scheduler"
229
+ )
230
+ # import correct text encoder classes
231
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
232
+ pretrained_model_name_or_path, revision
233
+ )
234
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
235
+ pretrained_model_name_or_path, revision, subfolder="text_encoder_2"
236
+ )
237
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
238
+ pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
239
+ )
240
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
241
+ pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision
242
+ )
243
+
244
+ vae = AutoencoderKL.from_pretrained(
245
+ pretrained_model_name_or_path, subfolder="vae", revision=revision
246
+ )
247
+ unet = UNet2DConditionModel.from_pretrained(
248
+ pretrained_model_name_or_path, subfolder="unet", revision=revision
249
+ )
250
+
251
+ vae.requires_grad_(False)
252
+ text_encoder_one.requires_grad_(False)
253
+ text_encoder_two.requires_grad_(False)
254
+
255
+ unet.to(device, dtype=weight_dtype)
256
+ vae.to(device, dtype=torch.float32)
257
+ text_encoder_one.to(device, dtype=weight_dtype)
258
+ text_encoder_two.to(device, dtype=weight_dtype)
259
+
260
+ return (
261
+ tokenizer_one,
262
+ tokenizer_two,
263
+ noise_scheduler,
264
+ text_encoder_one,
265
+ text_encoder_two,
266
+ vae,
267
+ unet,
268
+ )
269
+
270
+
271
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
272
+ """
273
+ Returns:
274
+ a state dict containing just the attention processor parameters.
275
+ """
276
+ attn_processors = unet.attn_processors
277
+
278
+ attn_processors_state_dict = {}
279
+
280
+ for attn_processor_key, attn_processor in attn_processors.items():
281
+ for parameter_key, parameter in attn_processor.state_dict().items():
282
+ attn_processors_state_dict[
283
+ f"{attn_processor_key}.{parameter_key}"
284
+ ] = parameter
285
+
286
+ return attn_processors_state_dict
287
+
288
+
289
+ class TokenEmbeddingsHandler:
290
+ def __init__(self, text_encoders, tokenizers):
291
+ self.text_encoders = text_encoders
292
+ self.tokenizers = tokenizers
293
+
294
+ self.train_ids: Optional[torch.Tensor] = None
295
+ self.inserting_toks: Optional[List[str]] = None
296
+ self.embeddings_settings = {}
297
+
298
+ def initialize_new_tokens(self, inserting_toks: List[str]):
299
+ idx = 0
300
+ for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
301
+ assert isinstance(
302
+ inserting_toks, list
303
+ ), "inserting_toks should be a list of strings."
304
+ assert all(
305
+ isinstance(tok, str) for tok in inserting_toks
306
+ ), "All elements in inserting_toks should be strings."
307
+
308
+ self.inserting_toks = inserting_toks
309
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
310
+ tokenizer.add_special_tokens(special_tokens_dict)
311
+ text_encoder.resize_token_embeddings(len(tokenizer))
312
+
313
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
314
+
315
+ # random initialization of new tokens
316
+
317
+ std_token_embedding = (
318
+ text_encoder.text_model.embeddings.token_embedding.weight.data.std()
319
+ )
320
+
321
+ print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}")
322
+
323
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
324
+ self.train_ids
325
+ ] = (
326
+ torch.randn(
327
+ len(self.train_ids), text_encoder.text_model.config.hidden_size
328
+ )
329
+ .to(device=self.device)
330
+ .to(dtype=self.dtype)
331
+ * std_token_embedding
332
+ )
333
+ self.embeddings_settings[
334
+ f"original_embeddings_{idx}"
335
+ ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone()
336
+ self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
337
+
338
+ inu = torch.ones((len(tokenizer),), dtype=torch.bool)
339
+ inu[self.train_ids] = False
340
+
341
+ self.embeddings_settings[f"index_no_updates_{idx}"] = inu
342
+
343
+ print(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
344
+
345
+ idx += 1
346
+
347
+ def save_embeddings(self, file_path: str):
348
+ assert (
349
+ self.train_ids is not None
350
+ ), "Initialize new tokens before saving embeddings."
351
+ tensors = {}
352
+ for idx, text_encoder in enumerate(self.text_encoders):
353
+ assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[
354
+ 0
355
+ ] == len(self.tokenizers[0]), "Tokenizers should be the same."
356
+ new_token_embeddings = (
357
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
358
+ self.train_ids
359
+ ]
360
+ )
361
+ tensors[f"text_encoders_{idx}"] = new_token_embeddings
362
+
363
+ save_file(tensors, file_path)
364
+
365
+ @property
366
+ def dtype(self):
367
+ return self.text_encoders[0].dtype
368
+
369
+ @property
370
+ def device(self):
371
+ return self.text_encoders[0].device
372
+
373
+ def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
374
+ # Assuming new tokens are of the format <s_i>
375
+ self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
376
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
377
+ tokenizer.add_special_tokens(special_tokens_dict)
378
+ text_encoder.resize_token_embeddings(len(tokenizer))
379
+
380
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
381
+ assert self.train_ids is not None, "New tokens could not be converted to IDs."
382
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
383
+ self.train_ids
384
+ ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
385
+
386
+ @torch.no_grad()
387
+ def retract_embeddings(self):
388
+ for idx, text_encoder in enumerate(self.text_encoders):
389
+ index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
390
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
391
+ index_no_updates
392
+ ] = (
393
+ self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
394
+ .to(device=text_encoder.device)
395
+ .to(dtype=text_encoder.dtype)
396
+ )
397
+
398
+ # for the parts that were updated, we need to normalize them
399
+ # to have the same std as before
400
+ std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
401
+
402
+ index_updates = ~index_no_updates
403
+ new_embeddings = (
404
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
405
+ index_updates
406
+ ]
407
+ )
408
+ off_ratio = std_token_embedding / new_embeddings.std()
409
+
410
+ new_embeddings = new_embeddings * (off_ratio**0.1)
411
+ text_encoder.text_model.embeddings.token_embedding.weight.data[
412
+ index_updates
413
+ ] = new_embeddings
414
+
415
+ def load_embeddings(self, file_path: str):
416
+ with safe_open(file_path, framework="pt", device=self.device.type) as f:
417
+ for idx in range(len(self.text_encoders)):
418
+ text_encoder = self.text_encoders[idx]
419
+ tokenizer = self.tokenizers[idx]
420
+
421
+ loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
422
+ self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
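
At inference time the Space only uses `TokenEmbeddingsHandler` from this file: it adds the pivotal-tuning tokens to both SDXL tokenizers and copies the trained vectors into the corresponding text encoders. A minimal sketch of that path, mirroring the pivotal branch of `run_lora` in `app.py` and assuming an already-loaded SDXL `pipe` plus a placeholder repo id and embeddings filename:

```python
# Sketch only: "username/example-sdxl-lora" and "embeddings.safetensors" are placeholders.
from huggingface_hub import hf_hub_download

embedding_path = hf_hub_download(
    repo_id="username/example-sdxl-lora", filename="embeddings.safetensors"
)
handler = TokenEmbeddingsHandler(
    [pipe.text_encoder, pipe.text_encoder_2],  # both SDXL text encoders
    [pipe.tokenizer, pipe.tokenizer_2],        # and their tokenizers
)
handler.load_embeddings(embedding_path)  # registers <s0>, <s1>, ... and loads their vectors
```

After this, prompts containing the trigger tokens resolve to the learned embeddings, which is why pivotal-tuning entries in the gallery also set `text_embedding_weights`.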
custom.css ADDED
@@ -0,0 +1,27 @@
+ #title{text-align: center;}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
+ #run_button{position:absolute;margin-top: 11px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
+ border-top-left-radius: 0px;}
+ #gallery{display:flex}
+ #gallery .grid-wrap{min-height: 100%;}
+ #accordion code{word-break: break-all;word-wrap: break-word;white-space: pre-wrap}
+ #soon{opacity: 0.55; pointer-events: none}
+ #soon button{width: 100%}
+ #share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;}
+ div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
+ #share-btn-container:hover {background-color: #060606}
+ #share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;}
+ #share-btn * {all: unset}
+ #share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
+ #share-btn-container .wrap {display: none !important}
+ #share-btn-container.hidden {display: none!important}
+ #extra_info{margin-top: 1em}
+ .pending .min {min-height: auto}
+
+ @media (max-width: 527px) {
+ #title h1{font-size: 2.2em}
+ #title img{width: 80px;}
+ #gallery {max-height: 370px}
+ }
images/3d_style_4.jpeg ADDED
images/LineAni.Redmond.png ADDED
images/LogoRedmond-LogoLoraForSDXL.jpeg ADDED
images/ToyRedmond-ToyLoraForSDXL10.png ADDED
images/corgi_brick.jpeg ADDED
images/crayon.png ADDED
images/dog.png ADDED
images/embroid.png ADDED
images/jojoso1.jpg ADDED
images/josef_koudelka.webp ADDED
images/lego-minifig-xl.jpeg ADDED
images/papercut_SDXL.jpeg ADDED
images/pikachu.webp ADDED
images/pixel-art-xl.jpeg ADDED
images/riding-min.jpg ADDED
images/the_fish.jpg ADDED
images/uglysonic.webp ADDED
images/voxel-xl-lora.png ADDED
images/watercolor.png ADDED
images/william_eggleston.webp ADDED
lora.png ADDED
lora.py ADDED
@@ -0,0 +1,1222 @@
1
+ # LoRA network module taken from https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import Dict, List, Optional, Tuple, Type, Union
9
+ from diffusers import AutoencoderKL
10
+ from transformers import CLIPTextModel
11
+ import numpy as np
12
+ import torch
13
+ import re
14
+
15
+
16
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
17
+
19
+
20
+
21
+ class LoRAModule(torch.nn.Module):
22
+ """
23
+ Replaces the forward method of the original Linear module instead of replacing the module itself.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ lora_name,
29
+ org_module: torch.nn.Module,
30
+ multiplier=1.0,
31
+ lora_dim=4,
32
+ alpha=1,
33
+ dropout=None,
34
+ rank_dropout=None,
35
+ module_dropout=None,
36
+ ):
37
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
38
+ super().__init__()
39
+ self.lora_name = lora_name
40
+
41
+ if org_module.__class__.__name__ == "Conv2d":
42
+ in_dim = org_module.in_channels
43
+ out_dim = org_module.out_channels
44
+ else:
45
+ in_dim = org_module.in_features
46
+ out_dim = org_module.out_features
47
+
48
+ # if limit_rank:
49
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
50
+ # if self.lora_dim != lora_dim:
51
+ # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
52
+ # else:
53
+ self.lora_dim = lora_dim
54
+
55
+ if org_module.__class__.__name__ == "Conv2d":
56
+ kernel_size = org_module.kernel_size
57
+ stride = org_module.stride
58
+ padding = org_module.padding
59
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
60
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
61
+ else:
62
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
63
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
64
+
65
+ if type(alpha) == torch.Tensor:
66
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
67
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
68
+ self.scale = alpha / self.lora_dim
69
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
70
+
71
+ # same as microsoft's
72
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
73
+ torch.nn.init.zeros_(self.lora_up.weight)
74
+
75
+ self.multiplier = multiplier
76
+ self.org_module = org_module # remove in applying
77
+ self.dropout = dropout
78
+ self.rank_dropout = rank_dropout
79
+ self.module_dropout = module_dropout
80
+
81
+ def apply_to(self):
82
+ self.org_forward = self.org_module.forward
83
+ self.org_module.forward = self.forward
84
+ del self.org_module
85
+
86
+ def forward(self, x):
87
+ org_forwarded = self.org_forward(x)
88
+
89
+ # module dropout
90
+ if self.module_dropout is not None and self.training:
91
+ if torch.rand(1) < self.module_dropout:
92
+ return org_forwarded
93
+
94
+ lx = self.lora_down(x)
95
+
96
+ # normal dropout
97
+ if self.dropout is not None and self.training:
98
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
99
+
100
+ # rank dropout
101
+ if self.rank_dropout is not None and self.training:
102
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
103
+ if len(lx.size()) == 3:
104
+ mask = mask.unsqueeze(1) # for Text Encoder
105
+ elif len(lx.size()) == 4:
106
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
107
+ lx = lx * mask
108
+
109
+ # scaling for rank dropout: treat as if the rank is changed
110
+ # this could be computed from the mask, but rank_dropout is used here for its augmentation-like effect
111
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
112
+ else:
113
+ scale = self.scale
114
+
115
+ lx = self.lora_up(lx)
116
+
117
+ return org_forwarded + lx * self.multiplier * scale
118
+
119
+
120
+ class LoRAInfModule(LoRAModule):
121
+ def __init__(
122
+ self,
123
+ lora_name,
124
+ org_module: torch.nn.Module,
125
+ multiplier=1.0,
126
+ lora_dim=4,
127
+ alpha=1,
128
+ **kwargs,
129
+ ):
130
+ # no dropout for inference
131
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
132
+
133
+ self.org_module_ref = [org_module] # keep a reference so it can be accessed later
134
+ self.enabled = True
135
+
136
+ # check regional or not by lora_name
137
+ self.text_encoder = False
138
+ if lora_name.startswith("lora_te_"):
139
+ self.regional = False
140
+ self.use_sub_prompt = True
141
+ self.text_encoder = True
142
+ elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
143
+ self.regional = False
144
+ self.use_sub_prompt = True
145
+ elif "time_emb" in lora_name:
146
+ self.regional = False
147
+ self.use_sub_prompt = False
148
+ else:
149
+ self.regional = True
150
+ self.use_sub_prompt = False
151
+
152
+ self.network: LoRANetwork = None
153
+
154
+ def set_network(self, network):
155
+ self.network = network
156
+
157
+ # merge the LoRA weights into the frozen original module
158
+ def merge_to(self, sd, dtype, device):
159
+ # get up/down weight
160
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
161
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
162
+
163
+ # extract weight from org_module
164
+ org_sd = self.org_module.state_dict()
165
+ weight = org_sd["weight"].to(torch.float)
166
+
167
+ # merge weight
168
+ if len(weight.size()) == 2:
169
+ # linear
170
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
171
+ elif down_weight.size()[2:4] == (1, 1):
172
+ # conv2d 1x1
173
+ weight = (
174
+ weight
175
+ + self.multiplier
176
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
177
+ * self.scale
178
+ )
179
+ else:
180
+ # conv2d 3x3
181
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
182
+ # print(conved.size(), weight.size(), module.stride, module.padding)
183
+ weight = weight + self.multiplier * conved * self.scale
184
+
185
+ # set weight to org_module
186
+ org_sd["weight"] = weight.to(dtype)
187
+ self.org_module.load_state_dict(org_sd)
188
+
189
+ # return this module's weight so that the merge can be reverted later
190
+ def get_weight(self, multiplier=None):
191
+ if multiplier is None:
192
+ multiplier = self.multiplier
193
+
194
+ # get up/down weight from module
195
+ up_weight = self.lora_up.weight.to(torch.float)
196
+ down_weight = self.lora_down.weight.to(torch.float)
197
+
198
+ # pre-calculated weight
199
+ if len(down_weight.size()) == 2:
200
+ # linear
201
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
202
+ elif down_weight.size()[2:4] == (1, 1):
203
+ # conv2d 1x1
204
+ weight = (
205
+ self.multiplier
206
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
207
+ * self.scale
208
+ )
209
+ else:
210
+ # conv2d 3x3
211
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
212
+ weight = self.multiplier * conved * self.scale
213
+
214
+ return weight
215
+
216
+ def set_region(self, region):
217
+ self.region = region
218
+ self.region_mask = None
219
+
220
+ def default_forward(self, x):
221
+ # print("default_forward", self.lora_name, x.size())
222
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
223
+
224
+ def forward(self, x):
225
+ if not self.enabled:
226
+ return self.org_forward(x)
227
+
228
+ if self.network is None or self.network.sub_prompt_index is None:
229
+ return self.default_forward(x)
230
+ if not self.regional and not self.use_sub_prompt:
231
+ return self.default_forward(x)
232
+
233
+ if self.regional:
234
+ return self.regional_forward(x)
235
+ else:
236
+ return self.sub_prompt_forward(x)
237
+
238
+ def get_mask_for_x(self, x):
239
+ # calculate size from shape of x
240
+ if len(x.size()) == 4:
241
+ h, w = x.size()[2:4]
242
+ area = h * w
243
+ else:
244
+ area = x.size()[1]
245
+
246
+ mask = self.network.mask_dic[area]
247
+ if mask is None:
248
+ raise ValueError(f"mask is None for resolution {area}")
249
+ if len(x.size()) != 4:
250
+ mask = torch.reshape(mask, (1, -1, 1))
251
+ return mask
252
+
253
+ def regional_forward(self, x):
254
+ if "attn2_to_out" in self.lora_name:
255
+ return self.to_out_forward(x)
256
+
257
+ if self.network.mask_dic is None: # sub_prompt_index >= 3
258
+ return self.default_forward(x)
259
+
260
+ # apply mask for LoRA result
261
+ lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
262
+ mask = self.get_mask_for_x(lx)
263
+ # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
264
+ lx = lx * mask
265
+
266
+ x = self.org_forward(x)
267
+ x = x + lx
268
+
269
+ if "attn2_to_q" in self.lora_name and self.network.is_last_network:
270
+ x = self.postp_to_q(x)
271
+
272
+ return x
273
+
274
+ def postp_to_q(self, x):
275
+ # repeat x to num_sub_prompts
276
+ has_real_uncond = x.size()[0] // self.network.batch_size == 3
277
+ qc = self.network.batch_size # uncond
278
+ qc += self.network.batch_size * self.network.num_sub_prompts # cond
279
+ if has_real_uncond:
280
+ qc += self.network.batch_size # real_uncond
281
+
282
+ query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
283
+ query[: self.network.batch_size] = x[: self.network.batch_size]
284
+
285
+ for i in range(self.network.batch_size):
286
+ qi = self.network.batch_size + i * self.network.num_sub_prompts
287
+ query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
288
+
289
+ if has_real_uncond:
290
+ query[-self.network.batch_size :] = x[-self.network.batch_size :]
291
+
292
+ # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
293
+ return query
294
+
295
+ def sub_prompt_forward(self, x):
296
+ if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
297
+ return self.org_forward(x)
298
+
299
+ emb_idx = self.network.sub_prompt_index
300
+ if not self.text_encoder:
301
+ emb_idx += self.network.batch_size
302
+
303
+ # apply sub prompt of X
304
+ lx = x[emb_idx :: self.network.num_sub_prompts]
305
+ lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
306
+
307
+ # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
308
+
309
+ x = self.org_forward(x)
310
+ x[emb_idx :: self.network.num_sub_prompts] += lx
311
+
312
+ return x
313
+
314
+ def to_out_forward(self, x):
315
+ # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
316
+
317
+ if self.network.is_last_network:
318
+ masks = [None] * self.network.num_sub_prompts
319
+ self.network.shared[self.lora_name] = (None, masks)
320
+ else:
321
+ lx, masks = self.network.shared[self.lora_name]
322
+
323
+ # call own LoRA
324
+ x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
325
+ lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
326
+
327
+ if self.network.is_last_network:
328
+ lx = torch.zeros(
329
+ (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
330
+ )
331
+ self.network.shared[self.lora_name] = (lx, masks)
332
+
333
+ # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
334
+ lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
335
+ masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
336
+
337
+ # if not last network, return x and masks
338
+ x = self.org_forward(x)
339
+ if not self.network.is_last_network:
340
+ return x
341
+
342
+ lx, masks = self.network.shared.pop(self.lora_name)
343
+
344
+ # if last network, combine separated x with mask weighted sum
345
+ has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
346
+
347
+ out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
348
+ out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
349
+ if has_real_uncond:
350
+ out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
351
+
352
+ # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
353
+ # for i in range(len(masks)):
354
+ # if masks[i] is None:
355
+ # masks[i] = torch.zeros_like(masks[-1])
356
+
357
+ mask = torch.cat(masks)
358
+ mask_sum = torch.sum(mask, dim=0) + 1e-4
359
+ for i in range(self.network.batch_size):
360
+ # process one image at a time
361
+ lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
362
+ lx1 = lx1 * mask
363
+ lx1 = torch.sum(lx1, dim=0)
364
+
365
+ xi = self.network.batch_size + i * self.network.num_sub_prompts
366
+ x1 = x[xi : xi + self.network.num_sub_prompts]
367
+ x1 = x1 * mask
368
+ x1 = torch.sum(x1, dim=0)
369
+ x1 = x1 / mask_sum
370
+
371
+ x1 = x1 + lx1
372
+ out[self.network.batch_size + i] = x1
373
+
374
+ # print("to_out_forward", x.size(), out.size(), has_real_uncond)
375
+ return out
376
+
377
+
378
+ def parse_block_lr_kwargs(nw_kwargs):
379
+ down_lr_weight = nw_kwargs.get("down_lr_weight", None)
380
+ mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
381
+ up_lr_weight = nw_kwargs.get("up_lr_weight", None)
382
+
383
+ # if none of the above are set, treat block-wise LR as disabled and return None
384
+ if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
385
+ return None, None, None
386
+
387
+ # extract learning rate weight for each block
388
+ if down_lr_weight is not None:
389
+ # if some parameters are not set, use zero
390
+ if "," in down_lr_weight:
391
+ down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
392
+
393
+ if mid_lr_weight is not None:
394
+ mid_lr_weight = float(mid_lr_weight)
395
+
396
+ if up_lr_weight is not None:
397
+ if "," in up_lr_weight:
398
+ up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
399
+
400
+ down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
401
+ down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
402
+ )
403
+
404
+ return down_lr_weight, mid_lr_weight, up_lr_weight
405
+
406
+
407
+ def create_network(
408
+ multiplier: float,
409
+ network_dim: Optional[int],
410
+ network_alpha: Optional[float],
411
+ vae: AutoencoderKL,
412
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
413
+ unet,
414
+ neuron_dropout: Optional[float] = None,
415
+ **kwargs,
416
+ ):
417
+ if network_dim is None:
418
+ network_dim = 4 # default
419
+ if network_alpha is None:
420
+ network_alpha = 1.0
421
+
422
+ # extract dim/alpha for conv2d, and block dim
423
+ conv_dim = kwargs.get("conv_dim", None)
424
+ conv_alpha = kwargs.get("conv_alpha", None)
425
+ if conv_dim is not None:
426
+ conv_dim = int(conv_dim)
427
+ if conv_alpha is None:
428
+ conv_alpha = 1.0
429
+ else:
430
+ conv_alpha = float(conv_alpha)
431
+
432
+ # block dim/alpha/lr
433
+ block_dims = kwargs.get("block_dims", None)
434
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
435
+
436
+ # if any of the above are specified, enable per-block dim (rank)
437
+ if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
438
+ block_alphas = kwargs.get("block_alphas", None)
439
+ conv_block_dims = kwargs.get("conv_block_dims", None)
440
+ conv_block_alphas = kwargs.get("conv_block_alphas", None)
441
+
442
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
443
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
444
+ )
445
+
446
+ # remove block dim/alpha without learning rate
447
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
448
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
449
+ )
450
+
451
+ else:
452
+ block_alphas = None
453
+ conv_block_dims = None
454
+ conv_block_alphas = None
455
+
456
+ # rank/module dropout
457
+ rank_dropout = kwargs.get("rank_dropout", None)
458
+ if rank_dropout is not None:
459
+ rank_dropout = float(rank_dropout)
460
+ module_dropout = kwargs.get("module_dropout", None)
461
+ if module_dropout is not None:
462
+ module_dropout = float(module_dropout)
463
+
464
+ # quite a lot of arguments here ( ^ω^)...
465
+ network = LoRANetwork(
466
+ text_encoder,
467
+ unet,
468
+ multiplier=multiplier,
469
+ lora_dim=network_dim,
470
+ alpha=network_alpha,
471
+ dropout=neuron_dropout,
472
+ rank_dropout=rank_dropout,
473
+ module_dropout=module_dropout,
474
+ conv_lora_dim=conv_dim,
475
+ conv_alpha=conv_alpha,
476
+ block_dims=block_dims,
477
+ block_alphas=block_alphas,
478
+ conv_block_dims=conv_block_dims,
479
+ conv_block_alphas=conv_block_alphas,
480
+ varbose=True,
481
+ )
482
+
483
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
484
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
485
+
486
+ return network
487
+
488
+
489
+ # keep in mind that this method may be called from outside
490
+ # network_dim and network_alpha always have default values here
491
+ # block_dims and block_alphas are either both None or both set
492
+ # conv_dim and conv_alpha are either both None or both set
493
+ def get_block_dims_and_alphas(
494
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
495
+ ):
496
+ num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
497
+
498
+ def parse_ints(s):
499
+ return [int(i) for i in s.split(",")]
500
+
501
+ def parse_floats(s):
502
+ return [float(i) for i in s.split(",")]
503
+
504
+ # parse block_dims and block_alphas; both always end up with values
505
+ if block_dims is not None:
506
+ block_dims = parse_ints(block_dims)
507
+ assert (
508
+ len(block_dims) == num_total_blocks
509
+ ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
510
+ else:
511
+ print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
512
+ block_dims = [network_dim] * num_total_blocks
513
+
514
+ if block_alphas is not None:
515
+ block_alphas = parse_floats(block_alphas)
516
+ assert (
517
+ len(block_alphas) == num_total_blocks
518
+ ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
519
+ else:
520
+ print(
521
+ f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になり���す"
522
+ )
523
+ block_alphas = [network_alpha] * num_total_blocks
524
+
525
+ # parse conv_block_dims and conv_block_alphas only when specified; otherwise use conv_dim and conv_alpha
526
+ if conv_block_dims is not None:
527
+ conv_block_dims = parse_ints(conv_block_dims)
528
+ assert (
529
+ len(conv_block_dims) == num_total_blocks
530
+ ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
531
+
532
+ if conv_block_alphas is not None:
533
+ conv_block_alphas = parse_floats(conv_block_alphas)
534
+ assert (
535
+ len(conv_block_alphas) == num_total_blocks
536
+ ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
537
+ else:
538
+ if conv_alpha is None:
539
+ conv_alpha = 1.0
540
+ print(
541
+ f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
542
+ )
543
+ conv_block_alphas = [conv_alpha] * num_total_blocks
544
+ else:
545
+ if conv_dim is not None:
546
+ print(
547
+ f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
548
+ )
549
+ conv_block_dims = [conv_dim] * num_total_blocks
550
+ conv_block_alphas = [conv_alpha] * num_total_blocks
551
+ else:
552
+ conv_block_dims = None
553
+ conv_block_alphas = None
554
+
555
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
556
+
557
+
558
+ # define per-block multipliers for layer-wise learning rates; may be called from outside
559
+ def get_block_lr_weight(
560
+ down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
561
+ ) -> Tuple[List[float], List[float], List[float]]:
562
+ # if no parameters are given, do nothing and keep the previous behavior
563
+ if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
564
+ return None, None, None
565
+
566
+ max_len = LoRANetwork.NUM_OF_BLOCKS # number of up/down blocks in the full model
567
+
568
+ def get_list(name_with_suffix) -> List[float]:
569
+ import math
570
+
571
+ tokens = name_with_suffix.split("+")
572
+ name = tokens[0]
573
+ base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
574
+
575
+ if name == "cosine":
576
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
577
+ elif name == "sine":
578
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
579
+ elif name == "linear":
580
+ return [i / (max_len - 1) + base_lr for i in range(max_len)]
581
+ elif name == "reverse_linear":
582
+ return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
583
+ elif name == "zeros":
584
+ return [0.0 + base_lr] * max_len
585
+ else:
586
+ print(
587
+ "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
588
+ % (name)
589
+ )
590
+ return None
591
+
592
+ if type(down_lr_weight) == str:
593
+ down_lr_weight = get_list(down_lr_weight)
594
+ if type(up_lr_weight) == str:
595
+ up_lr_weight = get_list(up_lr_weight)
596
+
597
+ if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
598
+ print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
599
+ print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
600
+ up_lr_weight = up_lr_weight[:max_len]
601
+ down_lr_weight = down_lr_weight[:max_len]
602
+
603
+ if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
604
+ print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
605
+ print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
606
+
607
+ if down_lr_weight != None and len(down_lr_weight) < max_len:
608
+ down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
609
+ if up_lr_weight != None and len(up_lr_weight) < max_len:
610
+ up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
611
+
612
+ if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
613
+ print("apply block learning rate / 階層別学習率を適用します。")
614
+ if down_lr_weight != None:
615
+ down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
616
+ print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
617
+ else:
618
+ print("down_lr_weight: all 1.0, すべて1.0")
619
+
620
+ if mid_lr_weight != None:
621
+ mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
622
+ print("mid_lr_weight:", mid_lr_weight)
623
+ else:
624
+ print("mid_lr_weight: 1.0")
625
+
626
+ if up_lr_weight != None:
627
+ up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
628
+ print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
629
+ else:
630
+ print("up_lr_weight: all 1.0, すべて1.0")
631
+
632
+ return down_lr_weight, mid_lr_weight, up_lr_weight
633
+
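+ # Example for get_block_lr_weight above: down_lr_weight="sine+.1" expands to 12 multipliers rising
+ # from 0.1 (shallowest down block) to 1.1 (deepest); "zeros+1" sets every block to 1.0. Values at or
+ # below zero_threshold are clamped to 0, which later removes that block from training entirely.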
634
+
635
+ # Exclude blocks whose lr_weight is 0 from block_dims; keep this callable from outside this module.
636
+ def remove_block_dims_and_alphas(
637
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
638
+ ):
639
+ # set 0 to block dim without learning rate to remove the block
640
+ if down_lr_weight != None:
641
+ for i, lr in enumerate(down_lr_weight):
642
+ if lr == 0:
643
+ block_dims[i] = 0
644
+ if conv_block_dims is not None:
645
+ conv_block_dims[i] = 0
646
+ if mid_lr_weight != None:
647
+ if mid_lr_weight == 0:
648
+ block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
649
+ if conv_block_dims is not None:
650
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
651
+ if up_lr_weight != None:
652
+ for i, lr in enumerate(up_lr_weight):
653
+ if lr == 0:
654
+ block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
655
+ if conv_block_dims is not None:
656
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
657
+
658
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
659
+
660
+
661
+ # Keep this callable from outside this module.
662
+ def get_block_index(lora_name: str) -> int:
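+ # Block index layout: 0..NUM_OF_BLOCKS-1 are the down blocks (index 0 itself is never produced, see below),
+ # NUM_OF_BLOCKS is the mid block, NUM_OF_BLOCKS+1 onward are the up blocks; -1 means the name is not a
+ # per-block U-Net LoRA module.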
663
+ block_idx = -1 # invalid lora name
664
+
665
+ m = RE_UPDOWN.search(lora_name)
666
+ if m:
667
+ g = m.groups()
668
+ i = int(g[1])
669
+ j = int(g[3])
670
+ if g[2] == "resnets":
671
+ idx = 3 * i + j
672
+ elif g[2] == "attentions":
673
+ idx = 3 * i + j
674
+ elif g[2] == "upsamplers" or g[2] == "downsamplers":
675
+ idx = 3 * i + 2
676
+
677
+ if g[0] == "down":
678
+ block_idx = 1 + idx # no LoRA module corresponds to index 0
679
+ elif g[0] == "up":
680
+ block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
681
+
682
+ elif "mid_block_" in lora_name:
683
+ block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
684
+
685
+ return block_idx
686
+
687
+
688
+ # Create network from weights for inference; weights are not loaded here (because they may be merged instead)
689
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
690
+ if weights_sd is None:
691
+ if os.path.splitext(file)[1] == ".safetensors":
692
+ from safetensors.torch import load_file, safe_open
693
+
694
+ weights_sd = load_file(file)
695
+ else:
696
+ weights_sd = torch.load(file, map_location="cpu")
697
+
698
+ # get dim/alpha mapping
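+ # Keys in a kohya-style state dict look like "<lora_name>.lora_down.weight", "<lora_name>.lora_up.weight"
+ # and "<lora_name>.alpha"; the rank is recovered from the first dimension of lora_down
+ # (e.g. a lora_down weight of shape [8, 768] means dim (rank) 8).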
699
+ modules_dim = {}
700
+ modules_alpha = {}
701
+ for key, value in weights_sd.items():
702
+ if "." not in key:
703
+ continue
704
+
705
+ lora_name = key.split(".")[0]
706
+ if "alpha" in key:
707
+ modules_alpha[lora_name] = value
708
+ elif "lora_down" in key:
709
+ dim = value.size()[0]
710
+ modules_dim[lora_name] = dim
711
+ # print(lora_name, value.size(), dim)
712
+
713
+ # support old LoRA without alpha
714
+ for key in modules_dim.keys():
715
+ if key not in modules_alpha:
716
+ modules_alpha[key] = modules_dim[key]
717
+
718
+ module_class = LoRAInfModule if for_inference else LoRAModule
719
+
720
+ network = LoRANetwork(
721
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
722
+ )
723
+
724
+ # block lr
725
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
726
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
727
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
728
+
729
+ return network, weights_sd
730
+
731
+
732
+ class LoRANetwork(torch.nn.Module):
733
+ NUM_OF_BLOCKS = 12 # number of up/down blocks in the full model
734
+
735
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
736
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
737
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
738
+ LORA_PREFIX_UNET = "lora_unet"
739
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
740
+
741
+ # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
742
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
743
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
744
+
745
+ def __init__(
746
+ self,
747
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
748
+ unet,
749
+ multiplier: float = 1.0,
750
+ lora_dim: int = 4,
751
+ alpha: float = 1,
752
+ dropout: Optional[float] = None,
753
+ rank_dropout: Optional[float] = None,
754
+ module_dropout: Optional[float] = None,
755
+ conv_lora_dim: Optional[int] = None,
756
+ conv_alpha: Optional[float] = None,
757
+ block_dims: Optional[List[int]] = None,
758
+ block_alphas: Optional[List[float]] = None,
759
+ conv_block_dims: Optional[List[int]] = None,
760
+ conv_block_alphas: Optional[List[float]] = None,
761
+ modules_dim: Optional[Dict[str, int]] = None,
762
+ modules_alpha: Optional[Dict[str, int]] = None,
763
+ module_class: Type[object] = LoRAModule,
764
+ varbose: Optional[bool] = False,
765
+ ) -> None:
766
+ """
767
+ LoRA network: there are a lot of arguments, but they follow one of these patterns:
768
+ 1. specify lora_dim and alpha
769
+ 2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
770
+ 3. specify block_dims and block_alphas : not applied to Conv2d 3x3
771
+ 4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas : also applied to Conv2d 3x3
772
+ 5. specify modules_dim and modules_alpha (for inference)
773
+ """
774
+ super().__init__()
775
+ self.multiplier = multiplier
776
+
777
+ self.lora_dim = lora_dim
778
+ self.alpha = alpha
779
+ self.conv_lora_dim = conv_lora_dim
780
+ self.conv_alpha = conv_alpha
781
+ self.dropout = dropout
782
+ self.rank_dropout = rank_dropout
783
+ self.module_dropout = module_dropout
784
+
785
+ if modules_dim is not None:
786
+ print(f"create LoRA network from weights")
787
+ elif block_dims is not None:
788
+ print(f"create LoRA network from block_dims")
789
+ print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
790
+ print(f"block_dims: {block_dims}")
791
+ print(f"block_alphas: {block_alphas}")
792
+ if conv_block_dims is not None:
793
+ print(f"conv_block_dims: {conv_block_dims}")
794
+ print(f"conv_block_alphas: {conv_block_alphas}")
795
+ else:
796
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
797
+ print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
798
+ if self.conv_lora_dim is not None:
799
+ print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
800
+
801
+ # create module instances
802
+ def create_modules(
803
+ is_unet: bool,
804
+ text_encoder_idx: Optional[int], # None, 1, 2
805
+ root_module: torch.nn.Module,
806
+ target_replace_modules: List[torch.nn.Module],
807
+ ) -> List[LoRAModule]:
808
+ prefix = (
809
+ self.LORA_PREFIX_UNET
810
+ if is_unet
811
+ else (
812
+ self.LORA_PREFIX_TEXT_ENCODER
813
+ if text_encoder_idx is None
814
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
815
+ )
816
+ )
817
+ loras = []
818
+ skipped = []
819
+ for name, module in root_module.named_modules():
820
+ if module.__class__.__name__ in target_replace_modules:
821
+ for child_name, child_module in module.named_modules():
822
+ is_linear = child_module.__class__.__name__ == "Linear"
823
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
824
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
825
+
826
+ if is_linear or is_conv2d:
827
+ lora_name = prefix + "." + name + "." + child_name
828
+ lora_name = lora_name.replace(".", "_")
829
+
830
+ dim = None
831
+ alpha = None
832
+
833
+ if modules_dim is not None:
834
+ # per-module dims/alphas specified (modules_dim given)
835
+ if lora_name in modules_dim:
836
+ dim = modules_dim[lora_name]
837
+ alpha = modules_alpha[lora_name]
838
+ elif is_unet and block_dims is not None:
839
+ # U-Net with block_dims specified
840
+ block_idx = get_block_index(lora_name)
841
+ if is_linear or is_conv2d_1x1:
842
+ dim = block_dims[block_idx]
843
+ alpha = block_alphas[block_idx]
844
+ elif conv_block_dims is not None:
845
+ dim = conv_block_dims[block_idx]
846
+ alpha = conv_block_alphas[block_idx]
847
+ else:
848
+ # default: target all modules
849
+ if is_linear or is_conv2d_1x1:
850
+ dim = self.lora_dim
851
+ alpha = self.alpha
852
+ elif self.conv_lora_dim is not None:
853
+ dim = self.conv_lora_dim
854
+ alpha = self.conv_alpha
855
+
856
+ if dim is None or dim == 0:
857
+ # record skipped modules so they can be reported
858
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
859
+ skipped.append(lora_name)
860
+ continue
861
+
862
+ lora = module_class(
863
+ lora_name,
864
+ child_module,
865
+ self.multiplier,
866
+ dim,
867
+ alpha,
868
+ dropout=dropout,
869
+ rank_dropout=rank_dropout,
870
+ module_dropout=module_dropout,
871
+ )
872
+ loras.append(lora)
873
+ return loras, skipped
874
+
875
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
876
+ print(text_encoders)
877
+ # create LoRA for text encoder
878
+ # TODO: creating every module each time is wasteful; worth revisiting
879
+ self.text_encoder_loras = []
880
+ skipped_te = []
881
+ for i, text_encoder in enumerate(text_encoders):
882
+ if len(text_encoders) > 1:
883
+ index = i + 1
884
+ print(f"create LoRA for Text Encoder {index}:")
885
+ else:
886
+ index = None
887
+ print(f"create LoRA for Text Encoder:")
888
+
889
+ print(text_encoder)
890
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
891
+ self.text_encoder_loras.extend(text_encoder_loras)
892
+ skipped_te += skipped
893
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
894
+
895
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
896
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
897
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
898
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
899
+
900
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
901
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
902
+
903
+ skipped = skipped_te + skipped_un
904
+ if varbose and len(skipped) > 0:
905
+ print(
906
+ f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
907
+ )
908
+ for name in skipped:
909
+ print(f"\t{name}")
910
+
911
+ self.up_lr_weight: List[float] = None
912
+ self.down_lr_weight: List[float] = None
913
+ self.mid_lr_weight: float = None
914
+ self.block_lr = False
915
+
916
+ # assertion
917
+ names = set()
918
+ for lora in self.text_encoder_loras + self.unet_loras:
919
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
920
+ names.add(lora.lora_name)
921
+
922
+ def set_multiplier(self, multiplier):
923
+ self.multiplier = multiplier
924
+ for lora in self.text_encoder_loras + self.unet_loras:
925
+ lora.multiplier = self.multiplier
926
+
927
+ def load_weights(self, file):
928
+ if os.path.splitext(file)[1] == ".safetensors":
929
+ from safetensors.torch import load_file
930
+
931
+ weights_sd = load_file(file)
932
+ else:
933
+ weights_sd = torch.load(file, map_location="cpu")
934
+ info = self.load_state_dict(weights_sd, False)
935
+ return info
936
+
937
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
938
+ if apply_text_encoder:
939
+ print("enable LoRA for text encoder")
940
+ else:
941
+ self.text_encoder_loras = []
942
+
943
+ if apply_unet:
944
+ print("enable LoRA for U-Net")
945
+ else:
946
+ self.unet_loras = []
947
+
948
+ for lora in self.text_encoder_loras + self.unet_loras:
949
+ lora.apply_to()
950
+ self.add_module(lora.lora_name, lora)
951
+
952
+ # returns whether the weights can be merged
953
+ def is_mergeable(self):
954
+ return True
955
+
956
+ # TODO refactor to common function with apply_to
957
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
958
+ apply_text_encoder = apply_unet = False
959
+ for key in weights_sd.keys():
960
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
961
+ apply_text_encoder = True
962
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
963
+ apply_unet = True
964
+
965
+ if apply_text_encoder:
966
+ print("enable LoRA for text encoder")
967
+ else:
968
+ self.text_encoder_loras = []
969
+
970
+ if apply_unet:
971
+ print("enable LoRA for U-Net")
972
+ else:
973
+ self.unet_loras = []
974
+
975
+ for lora in self.text_encoder_loras + self.unet_loras:
976
+ sd_for_lora = {}
977
+ for key in weights_sd.keys():
978
+ if key.startswith(lora.lora_name):
979
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
980
+ lora.merge_to(sd_for_lora, dtype, device)
981
+
982
+ print(f"weights are merged")
983
+
984
+ # Define per-block learning-rate multipliers; the argument order is reversed (up, mid, down), but leave it as is for now.
985
+ def set_block_lr_weight(
986
+ self,
987
+ up_lr_weight: List[float] = None,
988
+ mid_lr_weight: float = None,
989
+ down_lr_weight: List[float] = None,
990
+ ):
991
+ self.block_lr = True
992
+ self.down_lr_weight = down_lr_weight
993
+ self.mid_lr_weight = mid_lr_weight
994
+ self.up_lr_weight = up_lr_weight
995
+
996
+ def get_lr_weight(self, lora: LoRAModule) -> float:
997
+ lr_weight = 1.0
998
+ block_idx = get_block_index(lora.lora_name)
999
+ if block_idx < 0:
1000
+ return lr_weight
1001
+
1002
+ if block_idx < LoRANetwork.NUM_OF_BLOCKS:
1003
+ if self.down_lr_weight != None:
1004
+ lr_weight = self.down_lr_weight[block_idx]
1005
+ elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
1006
+ if self.mid_lr_weight != None:
1007
+ lr_weight = self.mid_lr_weight
1008
+ elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
1009
+ if self.up_lr_weight != None:
1010
+ lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
1011
+
1012
+ return lr_weight
1013
+
1014
+ # TODO: it might be good to allow separate learning rates for the two Text Encoders
1015
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1016
+ self.requires_grad_(True)
1017
+ all_params = []
1018
+
1019
+ def enumerate_params(loras):
1020
+ params = []
1021
+ for lora in loras:
1022
+ params.extend(lora.parameters())
1023
+ return params
1024
+
1025
+ if self.text_encoder_loras:
1026
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
1027
+ if text_encoder_lr is not None:
1028
+ param_data["lr"] = text_encoder_lr
1029
+ all_params.append(param_data)
1030
+
1031
+ if self.unet_loras:
1032
+ if self.block_lr:
1033
+ # group the LoRA modules by block so learning rates can be tracked per block
1034
+ block_idx_to_lora = {}
1035
+ for lora in self.unet_loras:
1036
+ idx = get_block_index(lora.lora_name)
1037
+ if idx not in block_idx_to_lora:
1038
+ block_idx_to_lora[idx] = []
1039
+ block_idx_to_lora[idx].append(lora)
1040
+
1041
+ # set the parameter group for each block
1042
+ for idx, block_loras in block_idx_to_lora.items():
1043
+ param_data = {"params": enumerate_params(block_loras)}
1044
+
1045
+ if unet_lr is not None:
1046
+ param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
1047
+ elif default_lr is not None:
1048
+ param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
1049
+ if ("lr" in param_data) and (param_data["lr"] == 0):
1050
+ continue
1051
+ all_params.append(param_data)
1052
+
1053
+ else:
1054
+ param_data = {"params": enumerate_params(self.unet_loras)}
1055
+ if unet_lr is not None:
1056
+ param_data["lr"] = unet_lr
1057
+ all_params.append(param_data)
1058
+
1059
+ return all_params
1060
+
1061
+ def enable_gradient_checkpointing(self):
1062
+ # not supported
1063
+ pass
1064
+
1065
+ def prepare_grad_etc(self, text_encoder, unet):
1066
+ self.requires_grad_(True)
1067
+
1068
+ def on_epoch_start(self, text_encoder, unet):
1069
+ self.train()
1070
+
1071
+ def get_trainable_params(self):
1072
+ return self.parameters()
1073
+
1074
+ def save_weights(self, file, dtype, metadata):
1075
+ if metadata is not None and len(metadata) == 0:
1076
+ metadata = None
1077
+
1078
+ state_dict = self.state_dict()
1079
+
1080
+ if dtype is not None:
1081
+ for key in list(state_dict.keys()):
1082
+ v = state_dict[key]
1083
+ v = v.detach().clone().to("cpu").to(dtype)
1084
+ state_dict[key] = v
1085
+
1086
+ if os.path.splitext(file)[1] == ".safetensors":
1087
+ from safetensors.torch import save_file
1088
+ from library import train_util
1089
+
1090
+ # Precalculate model hashes to save time on indexing
1091
+ if metadata is None:
1092
+ metadata = {}
1093
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
1094
+ metadata["sshs_model_hash"] = model_hash
1095
+ metadata["sshs_legacy_hash"] = legacy_hash
1096
+
1097
+ save_file(state_dict, file, metadata)
1098
+ else:
1099
+ torch.save(state_dict, file)
1100
+
1101
+ # mask is a tensor with values from 0 to 1
1102
+ def set_region(self, sub_prompt_index, is_last_network, mask):
1103
+ if mask.max() == 0:
1104
+ mask = torch.ones_like(mask)
1105
+
1106
+ self.mask = mask
1107
+ self.sub_prompt_index = sub_prompt_index
1108
+ self.is_last_network = is_last_network
1109
+
1110
+ for lora in self.text_encoder_loras + self.unet_loras:
1111
+ lora.set_network(self)
1112
+
1113
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
1114
+ self.batch_size = batch_size
1115
+ self.num_sub_prompts = num_sub_prompts
1116
+ self.current_size = (height, width)
1117
+ self.shared = shared
1118
+
1119
+ # create masks
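+ # The masks are stored keyed by their flattened spatial size (mh * mw) so that a LoRA module can
+ # later pick the mask matching the latent resolution it sees at its depth in the U-Net; the latent
+ # starts at 1/8 of the image size and is halved (rounding up) at each level below.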
1120
+ mask = self.mask
1121
+ mask_dic = {}
1122
+ mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1123
+ ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1124
+ dtype = ref_weight.dtype
1125
+ device = ref_weight.device
1126
+
1127
+ def resize_add(mh, mw):
1128
+ # print(mh, mw, mh * mw)
1129
+ m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1130
+ m = m.to(device, dtype=dtype)
1131
+ mask_dic[mh * mw] = m
1132
+
1133
+ h = height // 8
1134
+ w = width // 8
1135
+ for _ in range(4):
1136
+ resize_add(h, w)
1137
+ if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1138
+ resize_add(h + h % 2, w + w % 2)
1139
+ h = (h + 1) // 2
1140
+ w = (w + 1) // 2
1141
+
1142
+ self.mask_dic = mask_dic
1143
+
1144
+ def backup_weights(self):
1145
+ # back up the original weights
1146
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1147
+ for lora in loras:
1148
+ org_module = lora.org_module_ref[0]
1149
+ if not hasattr(org_module, "_lora_org_weight"):
1150
+ sd = org_module.state_dict()
1151
+ org_module._lora_org_weight = sd["weight"].detach().clone()
1152
+ org_module._lora_restored = True
1153
+
1154
+ def restore_weights(self):
1155
+ # restore the original weights
1156
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1157
+ for lora in loras:
1158
+ org_module = lora.org_module_ref[0]
1159
+ if not org_module._lora_restored:
1160
+ sd = org_module.state_dict()
1161
+ sd["weight"] = org_module._lora_org_weight
1162
+ org_module.load_state_dict(sd)
1163
+ org_module._lora_restored = True
1164
+
1165
+ def pre_calculation(self):
1166
+ # precompute: merge the LoRA weights into the original weights ahead of time
1167
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1168
+ for lora in loras:
1169
+ org_module = lora.org_module_ref[0]
1170
+ sd = org_module.state_dict()
1171
+
1172
+ org_weight = sd["weight"]
1173
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1174
+ sd["weight"] = org_weight + lora_weight
1175
+ assert sd["weight"].shape == org_weight.shape
1176
+ org_module.load_state_dict(sd)
1177
+
1178
+ org_module._lora_restored = False
1179
+ lora.enabled = False
1180
+
1181
+ def apply_max_norm_regularization(self, max_norm_value, device):
1182
+ downkeys = []
1183
+ upkeys = []
1184
+ alphakeys = []
1185
+ norms = []
1186
+ keys_scaled = 0
1187
+
1188
+ state_dict = self.state_dict()
1189
+ for key in state_dict.keys():
1190
+ if "lora_down" in key and "weight" in key:
1191
+ downkeys.append(key)
1192
+ upkeys.append(key.replace("lora_down", "lora_up"))
1193
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
1194
+
1195
+ for i in range(len(downkeys)):
1196
+ down = state_dict[downkeys[i]].to(device)
1197
+ up = state_dict[upkeys[i]].to(device)
1198
+ alpha = state_dict[alphakeys[i]].to(device)
1199
+ dim = down.shape[0]
1200
+ scale = alpha / dim
1201
+
1202
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1203
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1204
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1205
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1206
+ else:
1207
+ updown = up @ down
1208
+
1209
+ updown *= scale
1210
+
1211
+ norm = updown.norm().clamp(min=max_norm_value / 2)
1212
+ desired = torch.clamp(norm, max=max_norm_value)
1213
+ ratio = desired.cpu() / norm.cpu()
1214
+ sqrt_ratio = ratio**0.5
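+ # scaling both lora_up and lora_down by sqrt(ratio) scales their product (the effective
+ # weight delta) by ratio, so the clamped norm is respected without favoring either factor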
1215
+ if ratio != 1:
1216
+ keys_scaled += 1
1217
+ state_dict[upkeys[i]] *= sqrt_ratio
1218
+ state_dict[downkeys[i]] *= sqrt_ratio
1219
+ scalednorm = updown.norm() * ratio
1220
+ norms.append(scalednorm.item())
1221
+
1222
+ return keys_scaled, sum(norms) / len(norms), max(norms)
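A minimal, illustrative usage sketch for the LoRA network code above; here pipe (a loaded SDXL diffusers pipeline) and lora_path (a kohya-style .safetensors file) are assumed placeholders, not names from the committed code:

    # pipe and lora_path are assumed to exist already
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    network, weights_sd = create_network_from_weights(
        1.0, lora_path, pipe.vae, text_encoders, pipe.unet, for_inference=True
    )
    network.apply_to(text_encoders, pipe.unet)   # hook the LoRA modules into the original modules
    info = network.load_weights(lora_path)       # then load the trained weights into them
    # alternatively, bake the weights in permanently instead of hooking:
    # network.merge_to(text_encoders, pipe.unet, weights_sd, torch.float16, "cuda")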
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git@aedd78767c99f7bc26a532622d4006280cc6c00d
2
+ transformers
3
+ safetensors
4
+ accelerate
sdxl_loras.json ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "image": "images/pixel-art-xl.jpeg",
4
+ "title": "Pixel Art XL",
5
+ "repo": "nerijs/pixel-art-xl",
6
+ "trigger_word": "pixel art",
7
+ "weights": "pixel-art-xl.safetensors",
8
+ "is_compatible": true
9
+ },
10
+ {
11
+ "image": "images/riding-min.jpg",
12
+ "title": "Tintin AI",
13
+ "repo": "Pclanglais/TintinIA",
14
+ "trigger_word": "drawing of tintin",
15
+ "weights": "pytorch_lora_weights.safetensors",
16
+ "is_compatible": true,
17
+ "is_nc": true
18
+ },
19
+ {
20
+ "image": "https://huggingface.co/ProomptEngineer/pe-balloon-diffusion-style/resolve/main/2095176.jpeg",
21
+ "title": "PE Balloon Diffusion",
22
+ "repo": "ProomptEngineer/pe-balloon-diffusion-style",
23
+ "trigger_word": "PEBalloonStyle",
24
+ "weights": "PE_BalloonStyle.safetensors",
25
+ "is_compatible": true
26
+ },
27
+ {
28
+ "image": "https://huggingface.co/joachimsallstrom/aether-cloud-lora-for-sdxl/resolve/main/2378710.jpeg",
29
+ "title": "Aether Cloud",
30
+ "repo": "joachimsallstrom/aether-cloud-lora-for-sdxl",
31
+ "trigger_word": "a cloud that looks like a",
32
+ "weights": "Aether_Cloud_v1.safetensors",
33
+ "is_compatible": true
34
+ },
35
+ {
36
+ "image": "images/crayon.png",
37
+ "title": "Crayon Style",
38
+ "repo": "ostris/crayon_style_lora_sdxl",
39
+ "trigger_word": "",
40
+ "weights": "crayons_v1_sdxl.safetensors",
41
+ "is_compatible": true
42
+ },
43
+ {
44
+ "image": "https://tjzk.replicate.delivery/models_models_cover_image/c8b21524-342a-4dd2-bb01-3e65349ed982/image_12.jpeg",
45
+ "title": "Zelda 64 SDXL",
46
+ "repo":"jbilcke-hf/sdxl-zelda64",
47
+ "trigger_word": "in the style of <s0><s1>",
48
+ "weights": "lora.safetensors",
49
+ "text_embedding_weights": "embeddings.pti",
50
+ "is_compatible": false,
51
+ "is_pivotal": true
52
+ },
53
+ {
54
+ "image": "images/papercut_SDXL.jpeg",
55
+ "title": "Papercut SDXL",
56
+ "repo": "TheLastBen/Papercut_SDXL",
57
+ "trigger_word": "papercut",
58
+ "weights": "papercut.safetensors",
59
+ "is_compatible": true
60
+ },
61
+ {
62
+ "image": "https://pbxt.replicate.delivery/8LKCty2D5b5BBBjylErfI8Xqf4OTSsnA0TIJccnpPct3GmeiA/out-0.png",
63
+ "title": "2004 bad digital photography",
64
+ "repo": "fofr/sdxl-2004",
65
+ "trigger_word": "2004, in the style of <s0><s1>",
66
+ "weights": "lora.safetensors",
67
+ "text_embedding_weights": "embeddings.pti",
68
+ "is_compatible": false,
69
+ "is_pivotal": true
70
+ },
71
+ {
72
+ "image": "https://huggingface.co/joachimsallstrom/aether-ghost-lora-for-sdxl/resolve/14de4e59a3f44dabc762855da208cb8f44a7ac78/ghost.png",
73
+ "title": "Aether Ghost",
74
+ "repo": "joachimsallstrom/aether-ghost-lora-for-sdxl",
75
+ "trigger_word": "transparent ghost",
76
+ "weights": "Aether_Ghost_v1.1_LoRA.safetensors",
77
+ "is_compatible": true
78
+ },
79
+ {
80
+ "image": "https://i.imgur.com/Su4bFgm.png",
81
+ "title": "Vulcan SDXL",
82
+ "repo": "davizca87/vulcan",
83
+ "trigger_word": "v5lcn",
84
+ "weights": "v5lcnXL-000004.safetensors",
85
+ "is_compatible": true
86
+ },
87
+ {
88
+ "image":"https://huggingface.co/artificialguybr/ColoringBookRedmond/resolve/main/00009-1364020674.png",
89
+ "title": "ColoringBook.Redmond",
90
+ "repo": "artificialguybr/ColoringBookRedmond",
91
+ "trigger_word": "ColoringBookAF",
92
+ "weights": "ColoringBookRedmond-ColoringBookAF.safetensors",
93
+ "is_compatible": true
94
+ },
95
+ {
96
+ "image": "https://huggingface.co/Norod78/SDXL-LofiGirl-Lora/resolve/main/SDXL-LofiGirl-Lora/Examples/_00044-20230829080050-45-the%20%20girl%20with%20a%20pearl%20earring%20the%20LofiGirl%20%20_lora_SDXL-LofiGirl-Lora_1_%2C%20Very%20detailed%2C%20clean%2C%20high%20quality%2C%20sharp%20image.jpg",
97
+ "title": "LoFi Girl SDXL",
98
+ "repo": "Norod78/SDXL-LofiGirl-Lora",
99
+ "trigger_word": "LofiGirl",
100
+ "weights": "SDXL-LofiGirl-Lora.safetensors",
101
+ "is_compatible": true
102
+ },
103
+ {
104
+ "image": "images/embroid.png",
105
+ "title": "Embroidery Style",
106
+ "repo": "ostris/embroidery_style_lora_sdxl",
107
+ "trigger_word": "",
108
+ "weights": "embroidered_style_v1_sdxl.safetensors",
109
+ "is_compatible": true
110
+ },
111
+ {
112
+ "image": "images/3d_style_4.jpeg",
113
+ "title": "3D Render Style",
114
+ "repo": "goofyai/3d_render_style_xl",
115
+ "trigger_word": "3d style",
116
+ "weights": "3d_render_style_xl.safetensors",
117
+ "is_compatible": true
118
+ },
119
+ {
120
+ "image": "images/watercolor.png",
121
+ "title": "Watercolor Style",
122
+ "repo": "ostris/watercolor_style_lora_sdxl",
123
+ "trigger_word": "",
124
+ "weights": "watercolor_v1_sdxl.safetensors",
125
+ "is_compatible": true
126
+ },
127
+ {
128
+ "image": "https://huggingface.co/veryVANYA/ps1-graphics-sdxl/resolve/main/2070471.jpeg",
129
+ "title": "PS1 Graphics v2 SDXL",
130
+ "repo":"veryVANYA/ps1-graphics-sdxl-v2",
131
+ "trigger_word": "ps1 style",
132
+ "weights": "ps1_style_SDXL_v2.safetensors",
133
+ "is_compatible": true
134
+ },
135
+ {
136
+ "image": "images/william_eggleston.webp",
137
+ "title": "William Eggleston Style",
138
+ "repo": "TheLastBen/William_Eggleston_Style_SDXL",
139
+ "trigger_word": "by william eggleston",
140
+ "weights": "wegg.safetensors",
141
+ "is_compatible": true
142
+ },
143
+ {
144
+ "image": "https://huggingface.co/davizca87/c-a-g-coinmaker/resolve/main/1722160.jpeg",
145
+ "title": "CAG Coinmaker",
146
+ "repo": "davizca87/c-a-g-coinmaker",
147
+ "trigger_word": "c01n",
148
+ "weights": "c01n-000010.safetensors",
149
+ "is_compatible": true
150
+ },
151
+ {
152
+ "image": "images/dog.png",
153
+ "title": "Cyborg Style",
154
+ "repo": "goofyai/cyborg_style_xl",
155
+ "trigger_word": "cyborg style",
156
+ "weights": "cyborg_style_xl-off.safetensors",
157
+ "is_compatible": true
158
+ },
159
+ {
160
+ "image": "images/ToyRedmond-ToyLoraForSDXL10.png",
161
+ "title": "Toy.Redmond",
162
+ "repo": "artificialguybr/ToyRedmond-ToyLoraForSDXL10",
163
+ "trigger_word": "FnkRedmAF",
164
+ "weights": "ToyRedmond-FnkRedmAF.safetensors",
165
+ "is_compatible": true
166
+ },
167
+ {
168
+ "image": "images/voxel-xl-lora.png",
169
+ "title": "Voxel XL",
170
+ "repo": "Fictiverse/Voxel_XL_Lora",
171
+ "trigger_word": "voxel style",
172
+ "weights": "VoxelXL_v1.safetensors",
173
+ "is_compatible": true
174
+ },
175
+ {
176
+ "image": "images/uglysonic.webp",
177
+ "title": "Ugly Sonic",
178
+ "repo": "minimaxir/sdxl-ugly-sonic-lora",
179
+ "trigger_word": "sonic the hedgehog",
180
+ "weights": "pytorch_lora_weights.bin",
181
+ "is_compatible": true
182
+ },
183
+ {
184
+ "image": "images/corgi_brick.jpeg",
185
+ "title": "Lego BrickHeadz",
186
+ "repo": "nerijs/lego-brickheadz-xl",
187
+ "trigger_word": "lego brickheadz",
188
+ "weights": "legobrickheadz-v1.0-000004.safetensors",
189
+ "is_compatible": true
190
+ },
191
+ {
192
+ "image": "images/lego-minifig-xl.jpeg",
193
+ "title": "Lego Minifig XL",
194
+ "repo": "nerijs/lego-minifig-xl",
195
+ "trigger_word": "lego minifig",
196
+ "weights": "legominifig-v1.0-000003.safetensors",
197
+ "is_compatible": true
198
+ },
199
+ {
200
+ "image": "images/jojoso1.jpg",
201
+ "title": "JoJo's Bizarre style",
202
+ "repo": "Norod78/SDXL-jojoso_style-Lora",
203
+ "trigger_word": "jojoso style",
204
+ "weights": "SDXL-jojoso_style-Lora-r8.safetensors",
205
+ "is_compatible": true
206
+ },
207
+ {
208
+ "image": "images/pikachu.webp",
209
+ "title": "Pikachu XL",
210
+ "repo": "TheLastBen/Pikachu_SDXL",
211
+ "trigger_word": "pikachu",
212
+ "weights": "pikachu.safetensors",
213
+ "is_compatible": true
214
+ },
215
+ {
216
+ "image": "images/LogoRedmond-LogoLoraForSDXL.jpeg",
217
+ "title": "Logo.Redmond",
218
+ "repo": "artificialguybr/LogoRedmond-LogoLoraForSDXL",
219
+ "trigger_word": "LogoRedAF",
220
+ "weights": "LogoRedmond_LogoRedAF.safetensors",
221
+ "is_compatible": true
222
+ },
223
+ {
224
+ "image": "https://huggingface.co/Norod78/SDXL-StickerSheet-Lora/resolve/main/Examples/00073-20230831113700-7780-Cthulhu%20StickerSheet%20%20_lora_SDXL-StickerSheet-Lora_1_%2C%20based%20on%20H.P%20Lovecraft%20stories%2C%20Very%20detailed%2C%20clean%2C%20high%20quality%2C%20sharp.jpg",
225
+ "title": "Sticker Sheet",
226
+ "repo": "Norod78/SDXL-StickerSheet-Lora",
227
+ "trigger_word": "StickerSheet",
228
+ "weights": "SDXL-StickerSheet-Lora.safetensors",
229
+ "is_compatible": true
230
+ },
231
+ {
232
+ "image": "images/LineAni.Redmond.png",
233
+ "title": "LinearManga.Redmond",
234
+ "repo": "artificialguybr/LineAniRedmond-LinearMangaSDXL",
235
+ "trigger_word": "LineAniAF",
236
+ "weights": "LineAniRedmond-LineAniAF.safetensors",
237
+ "is_compatible": true
238
+ },
239
+ {
240
+ "image": "images/josef_koudelka.webp",
241
+ "title": "Josef Koudelka Style",
242
+ "repo": "TheLastBen/Josef_Koudelka_Style_SDXL",
243
+ "trigger_word": "by josef koudelka",
244
+ "weights": "koud.safetensors",
245
+ "is_compatible": true
246
+ },
247
+ {
248
+ "image": "https://huggingface.co/goofyai/Leonardo_Ai_Style_Illustration/resolve/main/leo-2.png",
249
+ "title": "Leonardo Style",
250
+ "repo": "goofyai/Leonardo_Ai_Style_Illustration",
251
+ "trigger_word": "leonardo style",
252
+ "weights": "leonardo_illustration.safetensors",
253
+ "is_compatible": true
254
+ },
255
+ {
256
+ "image":"https://huggingface.co/Norod78/SDXL-simpstyle-Lora/resolve/main/Examples/00006-20230820150225-558-the%20girl%20with%20a%20pearl%20earring%20by%20johannes%20vermeer%20simpstyle%20_lora_SDXL-simpstyle-Lora_1_%2C%20Very%20detailed%2C%20clean%2C%20high%20quality%2C%20sh.jpg",
257
+ "title": "SimpStyle",
258
+ "repo": "Norod78/SDXL-simpstyle-Lora",
259
+ "trigger_word":"simpstyle",
260
+ "weights": "SDXL-simpstyle-Lora-r8.safetensors",
261
+ "is_compatible": true
262
+ },
263
+ {
264
+ "image":"https://huggingface.co/artificialguybr/StoryBookRedmond/resolve/main/00162-1569823442.png",
265
+ "title": "Storybook.Redmond",
266
+ "repo": "artificialguybr/StoryBookRedmond",
267
+ "trigger_word":"KidsRedmAF",
268
+ "weights": "StoryBookRedmond-KidsRedmAF.safetensors",
269
+ "is_compatible": true
270
+ },
271
+ {
272
+ "image": "https://huggingface.co/chillpixel/blacklight-makeup-sdxl-lora/resolve/main/preview.png",
273
+ "title": "Blacklight Makeup",
274
+ "repo":"chillpixel/blacklight-makeup-sdxl-lora",
275
+ "trigger_word": "with blacklight makeup",
276
+ "weights": "pytorch_lora_weights.bin",
277
+ "is_compatible": true
278
+ }
279
+ ]
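A minimal sketch of how one of the gallery entries above might be consumed; the field names come from the JSON itself, while the download call and the prompt wiring are illustrative assumptions:

    import json
    from huggingface_hub import hf_hub_download

    with open("sdxl_loras.json") as f:
        loras = json.load(f)

    chosen = next(l for l in loras if l["is_compatible"])        # e.g. Pixel Art XL
    weights_path = hf_hub_download(chosen["repo"], chosen["weights"])
    prompt = f"{chosen['trigger_word']}, portrait of a corgi"    # prepend the trigger word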
share_btn.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
2
+ <path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
3
+ <path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
4
+ </svg>"""
5
+
6
+ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
7
+ style="color: #ffffff;
8
+ "
9
+ xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
10
+
11
+ share_js = """async () => {
12
+ async function uploadFile(file){
13
+ const UPLOAD_URL = 'https://huggingface.co/uploads';
14
+ const response = await fetch(UPLOAD_URL, {
15
+ method: 'POST',
16
+ headers: {
17
+ 'Content-Type': file.type,
18
+ 'X-Requested-With': 'XMLHttpRequest',
19
+ },
20
+ body: file, /// <- File inherits from Blob
21
+ });
22
+ const url = await response.text();
23
+ return url;
24
+ }
25
+
26
+ async function getInputImgFile(imgEl){
27
+ const res = await fetch(imgEl.src);
28
+ const blob = await res.blob();
29
+ const imgId = Date.now() % 200;
30
+ const isPng = imgEl.src.startsWith(`data:image/png`);
31
+ if(isPng){
32
+ const fileName = `sd-perception-${{imgId}}.png`;
33
+ return new File([blob], fileName, { type: 'image/png' });
34
+ }else{
35
+ const fileName = `sd-perception-${{imgId}}.jpg`;
36
+ return new File([blob], fileName, { type: 'image/jpeg' });
37
+ }
38
+ }
39
+
40
+ const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
41
+ const selectedLoRA = gradioEl.querySelector('#selected_lora').innerHTML;
42
+ const inputPrompt = gradioEl.querySelector('#prompt input').value;
43
+ const outputImgEl = gradioEl.querySelector('#result-image img');
44
+
45
+ const shareBtnEl = gradioEl.querySelector('#share-btn');
46
+ const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
47
+ const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
48
+
49
+ shareBtnEl.style.pointerEvents = 'none';
50
+ shareIconEl.style.display = 'none';
51
+ loadingIconEl.style.removeProperty('display');
52
+
53
+ const inputFile = await getInputImgFile(outputImgEl);
54
+ const urlInputImg = await uploadFile(inputFile);
55
+
56
+ const descriptionMd = `
57
+
58
+ ${selectedLoRA}
59
+
60
+ ### Prompt
61
+ ${inputPrompt}
62
+
63
+ #### Generated Image:
64
+ <img src="${urlInputImg}" />
65
+ `;
66
+ const params = new URLSearchParams({
67
+ title: inputPrompt,
68
+ description: descriptionMd,
69
+ preview: true
70
+ });
71
+ const paramsStr = params.toString();
72
+ window.open(`https://huggingface.co/spaces/multimodalart/LoraTheExplorer/discussions/new?${paramsStr}`, '_blank');
73
+ shareBtnEl.style.removeProperty('pointer-events');
74
+ shareIconEl.style.removeProperty('display');
75
+ loadingIconEl.style.display = 'none';
76
+ }"""