barreloflube commited on
Commit
37112ef
Β·
1 Parent(s): 152e870

Refactor flux_helpers.py to enable or disable Vae

Browse files
Files changed (37) hide show
  1. app.py +23 -30
  2. config.py +52 -13
  3. modules/events/sdxl_events.py +0 -0
  4. modules/helpers/sdxl_helpers.py +0 -0
  5. old2/app.py +36 -0
  6. old2/config.py +36 -0
  7. old2/modules/events/common_events.py +264 -0
  8. {modules β†’ old2/modules}/events/flux_events.py +1 -150
  9. old2/modules/events/sdxl_events.py +170 -0
  10. {modules β†’ old2/modules}/helpers/common_helpers.py +3 -0
  11. {modules β†’ old2/modules}/helpers/flux_helpers.py +1 -1
  12. old2/modules/helpers/sdxl_helpers.py +122 -0
  13. {modules β†’ old2/modules}/pipelines/common_pipelines.py +0 -0
  14. {modules β†’ old2/modules}/pipelines/flux_pipelines.py +0 -0
  15. {modules β†’ old2/modules}/pipelines/sdxl_pipelines.py +0 -0
  16. {old β†’ old2/old}/app.py +0 -0
  17. {old β†’ old2/old}/app2.py +0 -0
  18. {old β†’ old2/old}/app3.py +0 -0
  19. {old β†’ old2/old}/src/tasks/images/init_sys.py +0 -0
  20. {old β†’ old2/old}/src/tasks/images/sd.py +0 -0
  21. {old β†’ old2/old}/src/ui/__init__.py +0 -0
  22. {old β†’ old2/old}/src/ui/audios.py +0 -0
  23. {old β†’ old2/old}/src/ui/images.py +0 -0
  24. {old β†’ old2/old}/src/ui/tabs/__init__.py +0 -0
  25. {old β†’ old2/old}/src/ui/tabs/images/flux.py +0 -0
  26. {old β†’ old2/old}/src/ui/talkinghead.py +0 -0
  27. {old β†’ old2/old}/src/ui/texts.py +0 -0
  28. {old β†’ old2/old}/src/ui/videos.py +0 -0
  29. {tabs β†’ old2/tabs}/audio_tab.py +0 -0
  30. {tabs β†’ old2/tabs}/image_tab.py +9 -18
  31. {tabs β†’ old2/tabs}/text_tab.py +0 -0
  32. {tabs β†’ old2/tabs}/video_tab.py +0 -0
  33. tabs/images/events.py +510 -0
  34. tabs/images/handlers.py +257 -0
  35. tabs/images/load_models.py +61 -0
  36. tabs/images/models.py +72 -0
  37. tabs/images/ui.py +179 -0
app.py CHANGED
@@ -1,36 +1,29 @@
1
  import gradio as gr
2
 
3
  from config import css
4
- from tabs.image_tab import image_tab
5
- from tabs.audio_tab import audio_tab
6
- from tabs.video_tab import video_tab
7
- from tabs.text_tab import text_tab
8
 
9
- def main():
10
- with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
11
- # Header
12
- with gr.Column(elem_classes="center-content"):
13
- gr.Markdown("""
14
- # πŸš€ AAI: All AI
15
- Unleash your creativity with our multi-modal AI platform.
16
- [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
17
- """)
18
 
19
- # Tabs
20
- with gr.Tabs():
21
- with gr.Tab(label="πŸ–ΌοΈ Image"):
22
- image_tab()
23
- with gr.Tab(label="🎡 Audio"):
24
- audio_tab()
25
- with gr.Tab(label="πŸŽ₯ Video"):
26
- video_tab()
27
- with gr.Tab(label="πŸ“ Text"):
28
- text_tab()
29
 
30
- demo.launch(
31
- share=False,
32
- debug=True,
33
- )
34
-
35
- if __name__ == "__main__":
36
- main()
 
1
  import gradio as gr
2
 
3
  from config import css
4
+ from tabs.images.ui import image_tab
 
 
 
5
 
6
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
7
+ # Header
8
+ with gr.Column(elem_classes="center-content"):
9
+ gr.Markdown("""
10
+ # πŸš€ AAI: All AI
11
+ Unleash your creativity with our multi-modal AI platform.
12
+ [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
13
+ """)
 
14
 
15
+ # Tabs
16
+ with gr.Tabs():
17
+ with gr.Tab(label="πŸ–ΌοΈ Image"):
18
+ image_tab()
19
+ # with gr.Tab(label="🎡 Audio"):
20
+ # audio_tab()
21
+ # with gr.Tab(label="πŸŽ₯ Video"):
22
+ # video_tab()
23
+ # with gr.Tab(label="πŸ“ Text"):
24
+ # text_tab()
25
 
26
+ demo.launch(
27
+ share=False,
28
+ debug=True,
29
+ )
 
 
 
config.py CHANGED
@@ -1,7 +1,9 @@
1
- # config.py
2
-
3
  import json
4
 
 
 
 
5
  css = """
6
  @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
7
  body {
@@ -23,14 +25,51 @@ body {
23
  """
24
 
25
 
26
- # Models
27
- flux_models = ["black-forest-labs/FLUX.1-dev"]
28
- sdxl_models = ["stabilityai/stable-diffusion-xl-base-1.0"]
29
-
30
-
31
- # Load LoRAs
32
- with open("data/loras/flux.json", "r") as f:
33
- flux_loras = json.load(f)
34
-
35
- with open("data/loras/sdxl.json", "r") as f:
36
- sdxl_loras = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
2
  import json
3
 
4
+ import torch
5
+
6
+
7
  css = """
8
  @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
9
  body {
 
25
  """
26
 
27
 
28
+ class Config:
29
+ # General
30
+ SECRET_KEY = os.environ.get('SECRET_KEY', '12345678')
31
+
32
+ # Images
33
+ # IMAGE_MODELS = ["black-forest-labs/FLUX.1-dev", "stabilityai/stable-diffusion-xl-base-1.0"]
34
+ IMAGES_MODELS = [{"repo_id": "black-forest-labs/FLUX.1-dev", "loader": "flux", "compute_type": torch.bfloat16,}, {"repo_id": "stabilityai/stable-diffusion-xl-base-1.0", "loader": "sdxl", "compute_type": torch.float16,}]
35
+ with open('data/loras/sdxl.json') as f:
36
+ IMAGES_LORAS_SDXL = json.load(f)
37
+ with open('data/loras/flux.json') as f:
38
+ IMAGES_LORAS_FLUX = json.load(f)
39
+ IMAGES_CONTROLNETS = [
40
+ {
41
+ "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
42
+ "name": "depth_xl",
43
+ "layers": ["depth"],
44
+ "loader": "sdxl",
45
+ "compute_type": torch.float16,
46
+ },
47
+ {
48
+ "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
49
+ "name": "canny_xl",
50
+ "layers": ["canny"],
51
+ "loader": "sdxl",
52
+ "compute_type": torch.float16,
53
+ },
54
+ {
55
+ "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
56
+ "name": "openpose_xl",
57
+ "layers": ["pose"],
58
+ "loader": "sdxl",
59
+ "compute_type": torch.float16,
60
+ },
61
+ {
62
+ "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
63
+ "name": "scribble_xl",
64
+ "layers": ["scribble"],
65
+ "loader": "sdxl",
66
+ "compute_type": torch.float16,
67
+ },
68
+ {
69
+ "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
70
+ "name": "flux1_union_pro",
71
+ "layers": ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"],
72
+ "loader": "flux-multi",
73
+ "compute_type": torch.bfloat16,
74
+ }
75
+ ]
modules/events/sdxl_events.py DELETED
File without changes
modules/helpers/sdxl_helpers.py DELETED
File without changes
old2/app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from config import css
4
+ from tabs.image_tab import image_tab
5
+ from tabs.audio_tab import audio_tab
6
+ from tabs.video_tab import video_tab
7
+ from tabs.text_tab import text_tab
8
+
9
+ def main():
10
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
11
+ # Header
12
+ with gr.Column(elem_classes="center-content"):
13
+ gr.Markdown("""
14
+ # πŸš€ AAI: All AI
15
+ Unleash your creativity with our multi-modal AI platform.
16
+ [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
17
+ """)
18
+
19
+ # Tabs
20
+ with gr.Tabs():
21
+ with gr.Tab(label="πŸ–ΌοΈ Image"):
22
+ image_tab()
23
+ with gr.Tab(label="🎡 Audio"):
24
+ audio_tab()
25
+ with gr.Tab(label="πŸŽ₯ Video"):
26
+ video_tab()
27
+ with gr.Tab(label="πŸ“ Text"):
28
+ text_tab()
29
+
30
+ demo.launch(
31
+ share=False,
32
+ debug=True,
33
+ )
34
+
35
+ if __name__ == "__main__":
36
+ main()
old2/config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+
3
+ import json
4
+
5
+ css = """
6
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
7
+ body {
8
+ font-family: 'Poppins', sans-serif !important;
9
+ }
10
+ .center-content {
11
+ text-align: center;
12
+ max-width: 600px;
13
+ margin: 0 auto;
14
+ padding: 20px;
15
+ }
16
+ .center-content h1 {
17
+ font-weight: 600;
18
+ margin-bottom: 1rem;
19
+ }
20
+ .center-content p {
21
+ margin-bottom: 1.5rem;
22
+ }
23
+ """
24
+
25
+
26
+ # Models
27
+ flux_models = ["black-forest-labs/FLUX.1-dev"]
28
+ sdxl_models = ["stabilityai/stable-diffusion-xl-base-1.0"]
29
+
30
+
31
+ # Load LoRAs
32
+ with open("data/loras/flux.json", "r") as f:
33
+ flux_loras = json.load(f)
34
+
35
+ with open("data/loras/sdxl.json", "r") as f:
36
+ sdxl_loras = json.load(f)
old2/modules/events/common_events.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import ModelCard
3
+
4
+ from config import Config
5
+
6
+
7
+ def selected_lora_from_gallery(evt: gr.SelectData):
8
+ return (
9
+ gr.update(
10
+ value=evt.index
11
+ )
12
+ )
13
+
14
+
15
+ def update_selected_lora(custom_lora):
16
+ link = custom_lora.split("/")
17
+
18
+ if len(link) == 2:
19
+ model_card = ModelCard.load(custom_lora)
20
+ trigger_word = model_card.data.get("instance_prompt", "")
21
+ image_url = f"""https://huggingface.co/{custom_lora}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
22
+
23
+ custom_lora_info_css = """
24
+ <style>
25
+ .custom-lora-info {
26
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
27
+ background: linear-gradient(135deg, #4a90e2, #7b61ff);
28
+ color: white;
29
+ padding: 16px;
30
+ border-radius: 8px;
31
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
32
+ margin: 16px 0;
33
+ }
34
+ .custom-lora-header {
35
+ font-size: 18px;
36
+ font-weight: 600;
37
+ margin-bottom: 12px;
38
+ }
39
+ .custom-lora-content {
40
+ display: flex;
41
+ align-items: center;
42
+ background-color: rgba(255, 255, 255, 0.1);
43
+ border-radius: 6px;
44
+ padding: 12px;
45
+ }
46
+ .custom-lora-image {
47
+ width: 80px;
48
+ height: 80px;
49
+ object-fit: cover;
50
+ border-radius: 6px;
51
+ margin-right: 16px;
52
+ }
53
+ .custom-lora-text h3 {
54
+ margin: 0 0 8px 0;
55
+ font-size: 16px;
56
+ font-weight: 600;
57
+ }
58
+ .custom-lora-text small {
59
+ font-size: 14px;
60
+ opacity: 0.9;
61
+ }
62
+ .custom-trigger-word {
63
+ background-color: rgba(255, 255, 255, 0.2);
64
+ padding: 2px 6px;
65
+ border-radius: 4px;
66
+ font-weight: 600;
67
+ }
68
+ </style>
69
+ """
70
+
71
+ custom_lora_info_html = f"""
72
+ <div class="custom-lora-info">
73
+ <div class="custom-lora-header">Custom LoRA: {custom_lora}</div>
74
+ <div class="custom-lora-content">
75
+ <img class="custom-lora-image" src="{image_url}" alt="LoRA preview">
76
+ <div class="custom-lora-text">
77
+ <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
78
+ <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
79
+ </div>
80
+ </div>
81
+ </div>
82
+ """
83
+
84
+ custom_lora_info_html = f"{custom_lora_info_css}{custom_lora_info_html}"
85
+
86
+ return (
87
+ gr.update( # selected_lora
88
+ value=custom_lora,
89
+ ),
90
+ gr.update( # custom_lora_info
91
+ value=custom_lora_info_html,
92
+ visible=True
93
+ )
94
+ )
95
+
96
+ else:
97
+ return (
98
+ gr.update( # selected_lora
99
+ value=custom_lora,
100
+ ),
101
+ gr.update( # custom_lora_info
102
+ value=custom_lora_info_html if len(link) == 0 else "",
103
+ visible=False
104
+ )
105
+ )
106
+
107
+
108
+ def update_lora_sliders(enabled_loras):
109
+ sliders = []
110
+ remove_buttons = []
111
+
112
+ for lora in enabled_loras:
113
+ sliders.append(
114
+ gr.update(
115
+ label=lora.get("repo_id", ""),
116
+ info=f"Trigger Word: {lora.get('trigger_word', '')}",
117
+ visible=True,
118
+ interactive=True
119
+ )
120
+ )
121
+ remove_buttons.append(
122
+ gr.update(
123
+ visible=True,
124
+ interactive=True
125
+ )
126
+ )
127
+
128
+ if len(sliders) < 6:
129
+ for i in range(len(sliders), 6):
130
+ sliders.append(
131
+ gr.update(
132
+ visible=False
133
+ )
134
+ )
135
+ remove_buttons.append(
136
+ gr.update(
137
+ visible=False
138
+ )
139
+ )
140
+
141
+ return *sliders, *remove_buttons
142
+
143
+
144
+ def remove_from_enabled_loras(enabled_loras, index):
145
+ enabled_loras.pop(index)
146
+ return (
147
+ gr.update(
148
+ value=enabled_loras
149
+ )
150
+ )
151
+
152
+
153
+ def update_custom_embedding(custom_embedding):
154
+ link = custom_embedding.split("/")
155
+
156
+ if len(link) == 2:
157
+ model_card = ModelCard.load(custom_embedding)
158
+ trigger_word = model_card.data.get("instance_prompt", "")
159
+ image_url = f"""https://huggingface.co/{custom_embedding}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
160
+
161
+ custom_embedding_info_css = """
162
+ <style>
163
+ .custom-embedding-info {
164
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
165
+ background: linear-gradient(135deg, #4a90e2, #7b61ff);
166
+ color: white;
167
+ padding: 16px;
168
+ border-radius: 8px;
169
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
170
+ margin: 16px 0;
171
+ }
172
+ .custom-embedding-header {
173
+ font-size: 18px;
174
+ font-weight: 600;
175
+ margin-bottom: 12px;
176
+ }
177
+ .custom-embedding-content {
178
+ display: flex;
179
+ align-items: center;
180
+ background-color: rgba(255, 255, 255, 0.1);
181
+ border-radius: 6px;
182
+ padding: 12px;
183
+ }
184
+ .custom-embedding-image {
185
+ width: 80px;
186
+ height: 80px;
187
+ object-fit: cover;
188
+ border-radius: 6px;
189
+ margin-right: 16px;
190
+ }
191
+ .custom-embedding-text h3 {
192
+ margin: 0 0 8px 0;
193
+ font-size: 16px;
194
+ font-weight: 600;
195
+ }
196
+ .custom-embedding-text small {
197
+ font-size: 14px;
198
+ opacity: 0.9;
199
+ }
200
+ .custom-trigger-word {
201
+ background-color: rgba(255, 255, 255, 0.2);
202
+ padding: 2px 6px;
203
+ border-radius: 4px;
204
+ font-weight: 600;
205
+ }
206
+ </style>
207
+ """
208
+
209
+ custom_embedding_info_html = f"""
210
+ <div class="custom-embedding-info">
211
+ <div class="custom-embedding-header">Custom Embed Model: {custom_embedding}</div>
212
+ <div class="custom-embedding-content">
213
+ <img class="custom-embedding-image" src="{image_url}" alt="Embedding model preview">
214
+ <div class="custom-embedding-text">
215
+ <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
216
+ <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
217
+ </div>
218
+ </div>
219
+ </div>
220
+ """
221
+
222
+ custom_embedding_info_html = f"{custom_embedding_info_css}{custom_embedding_info_html}"
223
+
224
+ return gr.update( # custom_embedding_info
225
+ value=custom_embedding_info_html,
226
+ visible=True
227
+ )
228
+
229
+
230
+ def add_to_embeddings(custom_embedding, embeddings):
231
+ link = custom_embedding.split("/")
232
+
233
+ if len(link) == 2:
234
+ model_card = ModelCard.load(custom_embedding)
235
+ trigger_word = model_card.data.get("instance_prompt", "")
236
+ image_url = f"""https://huggingface.co/{custom_embedding}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
237
+
238
+ embeddings.append({
239
+ "repo_id": custom_embedding,
240
+ "trigger_word": trigger_word
241
+ })
242
+
243
+ return (
244
+ gr.update( # custom_embedding
245
+ value=""
246
+ ),
247
+ gr.update( # custom_embedding_info
248
+ value="",
249
+ visible=False
250
+ ),
251
+ gr.update( # embeddings
252
+ value=embeddings
253
+ )
254
+ )
255
+
256
+
257
+ def remove_from_embeddings(embeddings, index):
258
+ embeddings.pop(index)
259
+ return (
260
+ gr.update(
261
+ value=embeddings
262
+ )
263
+ )
264
+
{modules β†’ old2/modules}/events/flux_events.py RENAMED
@@ -1,6 +1,3 @@
1
- import json
2
- from typing import List
3
-
4
  import spaces
5
  import gradio as gr
6
  from huggingface_hub import ModelCard
@@ -34,107 +31,6 @@ def update_fast_generation(fast_generation):
34
  )
35
 
36
 
37
- def selected_lora_from_gallery(evt: gr.SelectData):
38
- return (
39
- gr.update(
40
- value=evt.index
41
- )
42
- )
43
-
44
-
45
- def update_selected_lora(custom_lora):
46
- link = custom_lora.split("/")
47
-
48
- if len(link) == 2:
49
- model_card = ModelCard.load(custom_lora)
50
- trigger_word = model_card.data.get("instance_prompt", "")
51
- image_url = f"""https://huggingface.co/{custom_lora}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
52
-
53
- custom_lora_info_css = """
54
- <style>
55
- .custom-lora-info {
56
- font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
57
- background: linear-gradient(135deg, #4a90e2, #7b61ff);
58
- color: white;
59
- padding: 16px;
60
- border-radius: 8px;
61
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
62
- margin: 16px 0;
63
- }
64
- .custom-lora-header {
65
- font-size: 18px;
66
- font-weight: 600;
67
- margin-bottom: 12px;
68
- }
69
- .custom-lora-content {
70
- display: flex;
71
- align-items: center;
72
- background-color: rgba(255, 255, 255, 0.1);
73
- border-radius: 6px;
74
- padding: 12px;
75
- }
76
- .custom-lora-image {
77
- width: 80px;
78
- height: 80px;
79
- object-fit: cover;
80
- border-radius: 6px;
81
- margin-right: 16px;
82
- }
83
- .custom-lora-text h3 {
84
- margin: 0 0 8px 0;
85
- font-size: 16px;
86
- font-weight: 600;
87
- }
88
- .custom-lora-text small {
89
- font-size: 14px;
90
- opacity: 0.9;
91
- }
92
- .custom-trigger-word {
93
- background-color: rgba(255, 255, 255, 0.2);
94
- padding: 2px 6px;
95
- border-radius: 4px;
96
- font-weight: 600;
97
- }
98
- </style>
99
- """
100
-
101
- custom_lora_info_html = f"""
102
- <div class="custom-lora-info">
103
- <div class="custom-lora-header">Custom LoRA: {custom_lora}</div>
104
- <div class="custom-lora-content">
105
- <img class="custom-lora-image" src="{image_url}" alt="LoRA preview">
106
- <div class="custom-lora-text">
107
- <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
108
- <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
109
- </div>
110
- </div>
111
- </div>
112
- """
113
-
114
- custom_lora_info_html = f"{custom_lora_info_css}{custom_lora_info_html}"
115
-
116
- return (
117
- gr.update( # selected_lora
118
- value=custom_lora,
119
- ),
120
- gr.update( # custom_lora_info
121
- value=custom_lora_info_html,
122
- visible=True
123
- )
124
- )
125
-
126
- else:
127
- return (
128
- gr.update( # selected_lora
129
- value=custom_lora,
130
- ),
131
- gr.update( # custom_lora_info
132
- value=custom_lora_info_html if len(link) == 0 else "",
133
- visible=False
134
- )
135
- )
136
-
137
-
138
  def add_to_enabled_loras(selected_lora, enabled_loras):
139
  lora_data = loras
140
  try:
@@ -170,52 +66,7 @@ def add_to_enabled_loras(selected_lora, enabled_loras):
170
  )
171
 
172
 
173
- def update_lora_sliders(enabled_loras):
174
- sliders = []
175
- remove_buttons = []
176
-
177
- for lora in enabled_loras:
178
- sliders.append(
179
- gr.update(
180
- label=lora.get("repo_id", ""),
181
- info=f"Trigger Word: {lora.get('trigger_word', '')}",
182
- visible=True,
183
- interactive=True
184
- )
185
- )
186
- remove_buttons.append(
187
- gr.update(
188
- visible=True,
189
- interactive=True
190
- )
191
- )
192
-
193
- if len(sliders) < 6:
194
- for i in range(len(sliders), 6):
195
- sliders.append(
196
- gr.update(
197
- visible=False
198
- )
199
- )
200
- remove_buttons.append(
201
- gr.update(
202
- visible=False
203
- )
204
- )
205
-
206
- return *sliders, *remove_buttons
207
-
208
-
209
- def remove_from_enabled_loras(enabled_loras, index):
210
- enabled_loras.pop(index)
211
- return (
212
- gr.update(
213
- value=enabled_loras
214
- )
215
- )
216
-
217
-
218
- @spaces.GPU(duration=120)
219
  def generate_image(
220
  model, prompt, fast_generation, enabled_loras,
221
  lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
  from huggingface_hub import ModelCard
 
31
  )
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def add_to_enabled_loras(selected_lora, enabled_loras):
35
  lora_data = loras
36
  try:
 
66
  )
67
 
68
 
69
+ @spaces.GPU(duration=75)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def generate_image(
71
  model, prompt, fast_generation, enabled_loras,
72
  lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,
old2/modules/events/sdxl_events.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ from huggingface_hub import ModelCard
4
+
5
+ from modules.helpers.common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq
6
+ from modules.helpers.sdxl_helpers import gen_img
7
+ from config import sdxl_loras
8
+
9
+ loras = sdxl_loras
10
+
11
+ # Event functions
12
+ def update_fast_generation(fast_generation):
13
+ if fast_generation:
14
+ return (
15
+ gr.update(
16
+ value=0.0
17
+ ),
18
+ gr.update(
19
+ value=8
20
+ )
21
+ )
22
+ else:
23
+ return (
24
+ gr.update(
25
+ value=7.0
26
+ ),
27
+ gr.update(
28
+ value=20
29
+ )
30
+ )
31
+
32
+
33
+ def add_to_enabled_loras(selected_lora, enabled_loras):
34
+ lora_data = loras
35
+ try:
36
+ selected_lora = int(selected_lora)
37
+
38
+ if 0 <= selected_lora: # is the index of the lora in the gallery
39
+ lora_info = lora_data[selected_lora]
40
+ enabled_loras.append({
41
+ "repo_id": lora_info["repo"],
42
+ "trigger_word": lora_info["trigger_word"]
43
+ })
44
+ except ValueError:
45
+ link = selected_lora.split("/")
46
+ if len(link) == 2:
47
+ model_card = ModelCard.load(selected_lora)
48
+ trigger_word = model_card.data.get("instance_prompt", "")
49
+ enabled_loras.append({
50
+ "repo_id": selected_lora,
51
+ "trigger_word": trigger_word
52
+ })
53
+
54
+ return (
55
+ gr.update( # selected_lora
56
+ value=""
57
+ ),
58
+ gr.update( # custom_lora_info
59
+ value="",
60
+ visible=False
61
+ ),
62
+ gr.update( # enabled_loras
63
+ value=enabled_loras
64
+ )
65
+ )
66
+
67
+
68
+ @spaces.GPU(duration=75)
69
+ def generate_image(
70
+ model, prompt, negative_prompt, fast_generation, enabled_loras, enabled_embeddings, # type: ignore
71
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
72
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image, # type: ignore
73
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength, # type: ignore
74
+ resize_mode,
75
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
76
+ image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
77
+ refiner, vae
78
+ ):
79
+ try:
80
+ base_args = {
81
+ "model": model,
82
+ "prompt": prompt,
83
+ "negative_prompt": negative_prompt,
84
+ "fast_generation": fast_generation,
85
+ "loras": None,
86
+ "embeddings": None,
87
+ "resize_mode": resize_mode,
88
+ "scheduler": scheduler,
89
+ "height": image_height,
90
+ "width": image_width,
91
+ "num_images_per_prompt": image_num_images_per_prompt,
92
+ "num_inference_steps": image_num_inference_steps,
93
+ "clip_skip": image_clip_skip,
94
+ "guidance_scale": image_guidance_scale,
95
+ "seed": image_seed,
96
+ "refiner": refiner,
97
+ "vae": vae,
98
+ "controlnet_config": None,
99
+ }
100
+ base_args = BaseReq(**base_args)
101
+
102
+ if len(enabled_loras) > 0:
103
+ base_args.loras = []
104
+ for enabled_lora, slider in zip(enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]):
105
+ if enabled_lora["repo_id"]:
106
+ base_args.loras.append({
107
+ "repo_id": enabled_lora["repo_id"],
108
+ "weight": slider
109
+ })
110
+
111
+ if len(enabled_embeddings) > 0:
112
+ base_args.embeddings = enabled_embeddings
113
+
114
+ image = None
115
+ mask_image = None
116
+ strength = None
117
+
118
+ if img2img_image:
119
+ image = img2img_image
120
+ strength = float(img2img_strength)
121
+
122
+ base_args = BaseImg2ImgReq(
123
+ **base_args.__dict__,
124
+ image=image,
125
+ strength=strength
126
+ )
127
+ elif inpaint_image:
128
+ image = inpaint_image['background'] if not all(pixel == (0, 0, 0) for pixel in list(inpaint_image['background'].getdata())) else None
129
+ mask_image = inpaint_image['layers'][0] if image else None
130
+ strength = float(inpaint_strength)
131
+
132
+ if image and mask_image:
133
+ base_args = BaseInpaintReq(
134
+ **base_args.__dict__,
135
+ image=image,
136
+ mask_image=mask_image,
137
+ strength=strength
138
+ )
139
+ elif any([canny_image, pose_image, depth_image]):
140
+ base_args.controlnet_config = ControlNetReq(
141
+ controlnets=[],
142
+ control_images=[],
143
+ controlnet_conditioning_scale=[]
144
+ )
145
+
146
+ if canny_image:
147
+ base_args.controlnet_config.controlnets.append("canny")
148
+ base_args.controlnet_config.control_images.append(canny_image)
149
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(canny_strength))
150
+ if pose_image:
151
+ base_args.controlnet_config.controlnets.append("pose")
152
+ base_args.controlnet_config.control_images.append(pose_image)
153
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(pose_strength))
154
+ if depth_image:
155
+ base_args.controlnet_config.controlnets.append("depth")
156
+ base_args.controlnet_config.control_images.append(depth_image)
157
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
158
+ if scribble_image:
159
+ base_args.controlnet_config.controlnets.append("scribble")
160
+ base_args.controlnet_config.control_images.append(scribble_image)
161
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(scribble_strength))
162
+ else:
163
+ base_args = BaseReq(**base_args.__dict__)
164
+
165
+ return gr.update(
166
+ value=gen_img(base_args),
167
+ interactive=True
168
+ )
169
+ except Exception as e:
170
+ raise gr.Error(f"Error: {e}") from e
{modules β†’ old2/modules}/helpers/common_helpers.py RENAMED
@@ -20,14 +20,17 @@ class ControlNetReq(BaseModel):
20
  class BaseReq(BaseModel):
21
  model: str = ""
22
  prompt: str = ""
 
23
  fast_generation: Optional[bool] = True
24
  loras: Optional[list] = []
 
25
  resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
26
  scheduler: Optional[str] = "euler_fl"
27
  height: int = 1024
28
  width: int = 1024
29
  num_images_per_prompt: int = 1
30
  num_inference_steps: int = 8
 
31
  guidance_scale: float = 3.5
32
  seed: Optional[int] = 0
33
  refiner: bool = False
 
20
  class BaseReq(BaseModel):
21
  model: str = ""
22
  prompt: str = ""
23
+ negative_prompt: Optional[str] = ""
24
  fast_generation: Optional[bool] = True
25
  loras: Optional[list] = []
26
+ embeddings: Optional[list] = None
27
  resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
28
  scheduler: Optional[str] = "euler_fl"
29
  height: int = 1024
30
  width: int = 1024
31
  num_images_per_prompt: int = 1
32
  num_inference_steps: int = 8
33
+ clip_skip: Optional[int] = None
34
  guidance_scale: float = 3.5
35
  seed: Optional[int] = 0
36
  refiner: bool = False
{modules β†’ old2/modules}/helpers/flux_helpers.py RENAMED
@@ -48,7 +48,7 @@ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
48
  elif isinstance(request, BaseReq):
49
  pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
50
 
51
- # Enable or Disable Refiner
52
  if request.vae:
53
  pipe_args["pipeline"].vae = flux_vae
54
  elif not request.vae:
 
48
  elif isinstance(request, BaseReq):
49
  pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
50
 
51
+ # Enable or Disable Vae
52
  if request.vae:
53
  pipe_args["pipeline"].vae = flux_vae
54
  elif not request.vae:
old2/modules/helpers/sdxl_helpers.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import (
6
+ AutoPipelineForText2Image,
7
+ AutoPipelineForImage2Image,
8
+ AutoPipelineForInpainting,
9
+ )
10
+ from huggingface_hub import hf_hub_download
11
+ from diffusers.schedulers import *
12
+ # from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
13
+
14
+ from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
15
+ from modules.pipelines.sdxl_pipelines import device, models, sdxl_vae, controlnets
16
+ from modules.pipelines.common_pipelines import refiner
17
+
18
+
19
+ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
20
+ def get_scheduler(pipeline, scheduler: str):
21
+ ...
22
+
23
+ for m in models:
24
+ if m['repo_id'] == request.model:
25
+ pipe_args = {
26
+ "pipeline": m['pipeline'],
27
+ }
28
+
29
+ # Set ControlNet config
30
+ if request.controlnet_config:
31
+ pipe_args["controlnet"] = [controlnets]
32
+
33
+ # Choose Pipeline Mode
34
+ if isinstance(request, BaseInpaintReq):
35
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
36
+ elif isinstance(request, BaseImg2ImgReq):
37
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
38
+ elif isinstance(request, BaseReq):
39
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
40
+
41
+ # Enable or Disable Refiner
42
+ if request.vae:
43
+ pipe_args["pipeline"].vae = sdxl_vae
44
+ elif not request.vae:
45
+ pipe_args["pipeline"].vae = None
46
+
47
+ # Set Scheduler
48
+ pipe_args["pipeline"].scheduler = get_scheduler(pipe_args["pipeline"], request.scheduler)
49
+
50
+ # Set Loras
51
+ if request.loras:
52
+ for i, lora in enumerate(request.loras):
53
+ pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
54
+ adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
55
+ adapter_weights = [lora['weight'] for lora in request.loras]
56
+
57
+ if request.fast_generation:
58
+ hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
59
+ hyper_weight = 0.125
60
+ pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
61
+ adapter_names.append("hyper_lora")
62
+ adapter_weights.append(hyper_weight)
63
+
64
+ pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
65
+
66
+ # Set Embeddings
67
+ if request.embeddings:
68
+ ...
69
+
70
+ return pipe_args
71
+
72
+
73
+ def get_prompt_attention(pipeline, prompt):
74
+ return get_weighted_text_embeddings_flux1(pipeline, prompt)
75
+
76
+
77
+ # Gen Function
78
+ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
79
+ pipe_args = get_pipe(request)
80
+ pipeline = pipe_args["pipeline"]
81
+ try:
82
+ positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)
83
+
84
+ # Common Args
85
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
+ args = {
87
+ 'prompt_embeds': positive_prompt_embeds,
88
+ 'pooled_prompt_embeds': positive_prompt_pooled,
89
+ 'height': request.height,
90
+ 'width': request.width,
91
+ 'num_images_per_prompt': request.num_images_per_prompt,
92
+ 'num_inference_steps': request.num_inference_steps,
93
+ 'clip_skip': request.clip_skip,
94
+ 'guidance_scale': request.guidance_scale,
95
+ 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
96
+ }
97
+
98
+ if request.controlnet_config:
99
+ args['control_mode'] = get_control_mode(request.controlnet_config)
100
+ args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
101
+ args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
102
+
103
+ if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
104
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
105
+ args['strength'] = request.strength
106
+
107
+ if isinstance(request, BaseInpaintReq):
108
+ args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
109
+
110
+ # Generate
111
+ images = pipeline(**args).images
112
+
113
+ # Refiner
114
+ if request.refiner:
115
+ images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
116
+
117
+ return images
118
+ except Exception as e:
119
+ cleanup(pipeline, request.loras)
120
+ raise gr.Error(f"Error: {e}")
121
+ finally:
122
+ cleanup(pipeline, request.loras)
{modules β†’ old2/modules}/pipelines/common_pipelines.py RENAMED
File without changes
{modules β†’ old2/modules}/pipelines/flux_pipelines.py RENAMED
File without changes
{modules β†’ old2/modules}/pipelines/sdxl_pipelines.py RENAMED
File without changes
{old β†’ old2/old}/app.py RENAMED
File without changes
{old β†’ old2/old}/app2.py RENAMED
File without changes
{old β†’ old2/old}/app3.py RENAMED
File without changes
{old β†’ old2/old}/src/tasks/images/init_sys.py RENAMED
File without changes
{old β†’ old2/old}/src/tasks/images/sd.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/__init__.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/audios.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/images.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/tabs/__init__.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/tabs/images/flux.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/talkinghead.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/texts.py RENAMED
File without changes
{old β†’ old2/old}/src/ui/videos.py RENAMED
File without changes
{tabs β†’ old2/tabs}/audio_tab.py RENAMED
File without changes
{tabs β†’ old2/tabs}/image_tab.py RENAMED
@@ -1,7 +1,9 @@
1
  # tabs/image_tab.py
 
2
 
3
  import gradio as gr
4
  from modules.helpers.common_helpers import *
 
5
 
6
 
7
  def image_tab():
@@ -15,11 +17,7 @@ def image_tab():
15
  def flux_tab():
16
  from modules.events.flux_events import (
17
  update_fast_generation,
18
- selected_lora_from_gallery,
19
- update_selected_lora,
20
  add_to_enabled_loras,
21
- update_lora_sliders,
22
- remove_from_enabled_loras,
23
  generate_image
24
  )
25
  from config import flux_models, flux_loras
@@ -188,15 +186,7 @@ def flux_tab():
188
  def sdxl_tab():
189
  from modules.events.sdxl_events import (
190
  update_fast_generation,
191
- selected_lora_from_gallery,
192
- update_selected_lora,
193
  add_to_enabled_loras,
194
- update_lora_sliders,
195
- remove_from_enabled_loras,
196
- add_to_embeddings,
197
- update_custom_embedding,
198
- remove_from_embeddings,
199
- generate_image
200
  )
201
  from config import sdxl_models, sdxl_loras
202
 
@@ -206,6 +196,7 @@ def sdxl_tab():
206
  with gr.Group() as image_options:
207
  model = gr.Dropdown(label="Models", choices=sdxl_models, value=sdxl_models[0], interactive=True)
208
  prompt = gr.Textbox(lines=5, label="Prompt")
 
209
  fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) πŸ§ͺ")
210
 
211
 
@@ -242,7 +233,7 @@ def sdxl_tab():
242
  custom_embedding = gr.Textbox(label="Custom Embedding")
243
  custom_embedding_info = gr.HTML(visible=False)
244
  add_embedding = gr.Button(value="Add Embedding")
245
- embeddings = gr.State(value=[])
246
  with gr.Group():
247
  with gr.Row():
248
  for i in range(6):
@@ -357,19 +348,19 @@ def sdxl_tab():
357
 
358
  # Embeddings
359
  custom_embedding.change(update_custom_embedding, custom_embedding, [custom_embedding_info])
360
- add_embedding.click(add_to_embeddings, [custom_embedding, embeddings], [custom_embedding, custom_embedding_info, embeddings])
361
  for i in range(6):
362
  globals()[f"embedding_remove_{i}"].click(
363
- lambda embeddings, index=i: remove_from_embeddings(embeddings, index),
364
- [embeddings],
365
- [embeddings]
366
  )
367
 
368
  # Generate Image
369
  generate_images.click(
370
  generate_image, # type: ignore
371
  [
372
- model, prompt, fast_generation, enabled_loras,
373
  lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
374
  img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image, # type: ignore
375
  img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength, # type: ignore
 
1
  # tabs/image_tab.py
2
+ import random
3
 
4
  import gradio as gr
5
  from modules.helpers.common_helpers import *
6
+ from modules.events.common_events import *
7
 
8
 
9
  def image_tab():
 
17
  def flux_tab():
18
  from modules.events.flux_events import (
19
  update_fast_generation,
 
 
20
  add_to_enabled_loras,
 
 
21
  generate_image
22
  )
23
  from config import flux_models, flux_loras
 
186
  def sdxl_tab():
187
  from modules.events.sdxl_events import (
188
  update_fast_generation,
 
 
189
  add_to_enabled_loras,
 
 
 
 
 
 
190
  )
191
  from config import sdxl_models, sdxl_loras
192
 
 
196
  with gr.Group() as image_options:
197
  model = gr.Dropdown(label="Models", choices=sdxl_models, value=sdxl_models[0], interactive=True)
198
  prompt = gr.Textbox(lines=5, label="Prompt")
199
+ negative_prompt = gr.Textbox(lines=5, label="Negative Prompt")
200
  fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) πŸ§ͺ")
201
 
202
 
 
233
  custom_embedding = gr.Textbox(label="Custom Embedding")
234
  custom_embedding_info = gr.HTML(visible=False)
235
  add_embedding = gr.Button(value="Add Embedding")
236
+ enabled_embeddings = gr.State(value=[])
237
  with gr.Group():
238
  with gr.Row():
239
  for i in range(6):
 
348
 
349
  # Embeddings
350
  custom_embedding.change(update_custom_embedding, custom_embedding, [custom_embedding_info])
351
+ add_embedding.click(add_to_embeddings, [custom_embedding, enabled_embeddings], [custom_embedding, custom_embedding_info, enabled_embeddings])
352
  for i in range(6):
353
  globals()[f"embedding_remove_{i}"].click(
354
+ lambda enabled_embeddings, index=i: remove_from_embeddings(enabled_embeddings, index),
355
+ [enabled_embeddings],
356
+ [enabled_embeddings]
357
  )
358
 
359
  # Generate Image
360
  generate_images.click(
361
  generate_image, # type: ignore
362
  [
363
+ model, prompt, negative_prompt, fast_generation, enabled_loras, enabled_embeddings,
364
  lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
365
  img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image, # type: ignore
366
  img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength, # type: ignore
{tabs β†’ old2/tabs}/text_tab.py RENAMED
File without changes
{tabs β†’ old2/tabs}/video_tab.py RENAMED
File without changes
tabs/images/events.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ from huggingface_hub import ModelCard
4
+
5
+ from config import Config
6
+ from .models import *
7
+ from .handlers import gen_img
8
+
9
+ # Common
10
+ def update_model_options(model):
11
+ for m in Config.IMAGES_MODELS:
12
+ if m['repo_id'] == model:
13
+ if m['loader'] == 'flux':
14
+ return (
15
+ gr.update( # negative_prompt
16
+ visible=False
17
+ ),
18
+ gr.update( # lora_gallery
19
+ value=Config.IMAGES_LORAS_FLUX
20
+ ),
21
+ gr.update( # embeddings_accordion
22
+ visible=False
23
+ ),
24
+ gr.update( # scribble_tab
25
+ visible=False
26
+ ),
27
+ gr.update( # scheduler
28
+ value='fm_euler'
29
+ ),
30
+ gr.update( # image_clip_skip
31
+ visible=False
32
+ ),
33
+ gr.update( # image_guidance_scale
34
+ value=3.5
35
+ )
36
+ )
37
+
38
+ elif m['loader'] == 'sdxl':
39
+ return (
40
+ gr.update( # negative_prompt
41
+ visible=True
42
+ ),
43
+ gr.update( # lora_gallery
44
+ value=Config.IMAGES_LORAS_SDXL
45
+ ),
46
+ gr.update( # embeddings_accordion
47
+ visible=True
48
+ ),
49
+ gr.update( # scribble_tab
50
+ visible=True
51
+ ),
52
+ gr.update( # scheduler
53
+ value='dpmpp_2m_sde_k'
54
+ ),
55
+ gr.update( # image_clip_skip
56
+ visible=True
57
+ ),
58
+ gr.update( # image_guidance_scale
59
+ value=7.0
60
+ )
61
+ )
62
+
63
+
64
+ def update_fast_generation(model, fast_generation):
65
+ for m in Config.IMAGES_MODELS:
66
+ if m['repo_id'] == model:
67
+ if m['loader'] == 'flux':
68
+ if fast_generation:
69
+ return (
70
+ gr.update( # image_num_inference_steps
71
+ value=8
72
+ ),
73
+ gr.update( # image_guidance_scale
74
+ value=3.5
75
+ )
76
+ )
77
+ else:
78
+ return (
79
+ gr.update( # image_num_inference_steps
80
+ value=20
81
+ ),
82
+ gr.update( # image_guidance_scale
83
+ value=3.5
84
+ )
85
+ )
86
+ elif m['loader'] == 'sdxl':
87
+ if fast_generation:
88
+ return (
89
+ gr.update( # image_num_inference_steps
90
+ value=8
91
+ ),
92
+ gr.update( # image_guidance_scale
93
+ value=1.0
94
+ )
95
+ )
96
+ else:
97
+ return (
98
+ gr.update( # image_num_inference_steps
99
+ value=20
100
+ ),
101
+ gr.update( # image_guidance_scale
102
+ value=7.0
103
+ )
104
+ )
105
+
106
+
107
+ # Loras
108
+ def selected_lora_from_gallery(evt: gr.SelectData):
109
+ return (
110
+ gr.update(
111
+ value=evt.index
112
+ )
113
+ )
114
+
115
+
116
+ def update_selected_lora(custom_lora):
117
+ link = custom_lora.split("/")
118
+
119
+ if len(link) == 2:
120
+ model_card = ModelCard.load(custom_lora)
121
+ trigger_word = model_card.data.get("instance_prompt", "")
122
+ image_url = f"""https://huggingface.co/{custom_lora}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
123
+
124
+ custom_lora_info_css = """
125
+ <style>
126
+ .custom-lora-info {
127
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
128
+ background: linear-gradient(135deg, #4a90e2, #7b61ff);
129
+ color: white;
130
+ padding: 16px;
131
+ border-radius: 8px;
132
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
133
+ margin: 16px 0;
134
+ }
135
+ .custom-lora-header {
136
+ font-size: 18px;
137
+ font-weight: 600;
138
+ margin-bottom: 12px;
139
+ }
140
+ .custom-lora-content {
141
+ display: flex;
142
+ align-items: center;
143
+ background-color: rgba(255, 255, 255, 0.1);
144
+ border-radius: 6px;
145
+ padding: 12px;
146
+ }
147
+ .custom-lora-image {
148
+ width: 80px;
149
+ height: 80px;
150
+ object-fit: cover;
151
+ border-radius: 6px;
152
+ margin-right: 16px;
153
+ }
154
+ .custom-lora-text h3 {
155
+ margin: 0 0 8px 0;
156
+ font-size: 16px;
157
+ font-weight: 600;
158
+ }
159
+ .custom-lora-text small {
160
+ font-size: 14px;
161
+ opacity: 0.9;
162
+ }
163
+ .custom-trigger-word {
164
+ background-color: rgba(255, 255, 255, 0.2);
165
+ padding: 2px 6px;
166
+ border-radius: 4px;
167
+ font-weight: 600;
168
+ }
169
+ </style>
170
+ """
171
+
172
+ custom_lora_info_html = f"""
173
+ <div class="custom-lora-info">
174
+ <div class="custom-lora-header">Custom LoRA: {custom_lora}</div>
175
+ <div class="custom-lora-content">
176
+ <img class="custom-lora-image" src="{image_url}" alt="LoRA preview">
177
+ <div class="custom-lora-text">
178
+ <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
179
+ <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
180
+ </div>
181
+ </div>
182
+ </div>
183
+ """
184
+
185
+ custom_lora_info_html = f"{custom_lora_info_css}{custom_lora_info_html}"
186
+
187
+ return (
188
+ gr.update( # selected_lora
189
+ value=custom_lora,
190
+ ),
191
+ gr.update( # custom_lora_info
192
+ value=custom_lora_info_html,
193
+ visible=True
194
+ )
195
+ )
196
+
197
+ else:
198
+ return (
199
+ gr.update( # selected_lora
200
+ value=custom_lora,
201
+ ),
202
+ gr.update( # custom_lora_info
203
+ value=custom_lora_info_html if len(link) == 0 else "",
204
+ visible=False
205
+ )
206
+ )
207
+
208
+
209
+ def update_lora_sliders(enabled_loras):
210
+ sliders = []
211
+ remove_buttons = []
212
+
213
+ for lora in enabled_loras:
214
+ sliders.append(
215
+ gr.update(
216
+ label=lora.get("repo_id", ""),
217
+ info=f"Trigger Word: {lora.get('trigger_word', '')}",
218
+ visible=True,
219
+ interactive=True
220
+ )
221
+ )
222
+ remove_buttons.append(
223
+ gr.update(
224
+ visible=True,
225
+ interactive=True
226
+ )
227
+ )
228
+
229
+ if len(sliders) < 6:
230
+ for i in range(len(sliders), 6):
231
+ sliders.append(
232
+ gr.update(
233
+ visible=False
234
+ )
235
+ )
236
+ remove_buttons.append(
237
+ gr.update(
238
+ visible=False
239
+ )
240
+ )
241
+
242
+ return *sliders, *remove_buttons
243
+
244
+
245
+ def remove_from_enabled_loras(enabled_loras, index):
246
+ enabled_loras.pop(index)
247
+ return (
248
+ gr.update(
249
+ value=enabled_loras
250
+ )
251
+ )
252
+
253
+
254
+ def add_to_enabled_loras(model, selected_lora, enabled_loras):
255
+
256
+ for m in Config.IMAGES_MODELS:
257
+ if m['repo_id'] == model.value:
258
+ lora_data = []
259
+ if m['loader'] == 'flux':
260
+ lora_data = Config.IMAGES_LORAS_FLUX
261
+ elif m['loader'] == 'sdxl':
262
+ lora_data = Config.IMAGES_LORAS_SDXL
263
+ try:
264
+ selected_lora = int(selected_lora)
265
+
266
+ if 0 <= selected_lora: # is the index of the lora in the gallery
267
+ lora_info = lora_data[selected_lora]
268
+ enabled_loras.append({
269
+ "repo_id": lora_info["repo"],
270
+ "trigger_word": lora_info["trigger_word"]
271
+ })
272
+ except ValueError:
273
+ link = selected_lora.split("/")
274
+ if len(link) == 2:
275
+ model_card = ModelCard.load(selected_lora)
276
+ trigger_word = model_card.data.get("instance_prompt", "")
277
+ enabled_loras.append({
278
+ "repo_id": selected_lora,
279
+ "trigger_word": trigger_word
280
+ })
281
+
282
+ return (
283
+ gr.update( # selected_lora
284
+ value=""
285
+ ),
286
+ gr.update( # custom_lora_info
287
+ value="",
288
+ visible=False
289
+ ),
290
+ gr.update( # enabled_loras
291
+ value=enabled_loras
292
+ )
293
+ )
294
+
295
+
296
+ # Custom Embedding
297
+ def update_custom_embedding(custom_embedding):
298
+ link = custom_embedding.split("/")
299
+
300
+ if len(link) == 2:
301
+ model_card = ModelCard.load(custom_embedding)
302
+ trigger_word = model_card.data.get("instance_prompt", "")
303
+ image_url = f"""https://huggingface.co/{custom_embedding}/resolve/main/{model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)}"""
304
+
305
+ custom_embedding_info_css = """
306
+ <style>
307
+ .custom-embedding-info {
308
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
309
+ background: linear-gradient(135deg, #4a90e2, #7b61ff);
310
+ color: white;
311
+ padding: 16px;
312
+ border-radius: 8px;
313
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
314
+ margin: 16px 0;
315
+ }
316
+ .custom-embedding-header {
317
+ font-size: 18px;
318
+ font-weight: 600;
319
+ margin-bottom: 12px;
320
+ }
321
+ .custom-embedding-content {
322
+ display: flex;
323
+ align-items: center;
324
+ background-color: rgba(255, 255, 255, 0.1);
325
+ border-radius: 6px;
326
+ padding: 12px;
327
+ }
328
+ .custom-embedding-image {
329
+ width: 80px;
330
+ height: 80px;
331
+ object-fit: cover;
332
+ border-radius: 6px;
333
+ margin-right: 16px;
334
+ }
335
+ .custom-embedding-text h3 {
336
+ margin: 0 0 8px 0;
337
+ font-size: 16px;
338
+ font-weight: 600;
339
+ }
340
+ .custom-embedding-text small {
341
+ font-size: 14px;
342
+ opacity: 0.9;
343
+ }
344
+ .custom-trigger-word {
345
+ background-color: rgba(255, 255, 255, 0.2);
346
+ padding: 2px 6px;
347
+ border-radius: 4px;
348
+ font-weight: 600;
349
+ }
350
+ </style>
351
+ """
352
+
353
+ custom_embedding_info_html = f"""
354
+ <div class="custom-embedding-info">
355
+ <div class="custom-embedding-header">Custom Embedding: {custom_embedding}</div>
356
+ <div class="custom-embedding-content">
357
+ <img class="custom-embedding-image" src="{image_url}" alt="Embedding preview">
358
+ <div class="custom-embedding-text">
359
+ <h3>{link[1].replace("-", " ").replace("_", " ")}</h3>
360
+ <small>{"Using: <span class='custom-trigger-word'>"+trigger_word+"</span> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}</small>
361
+ </div>
362
+ </div>
363
+ </div>
364
+ """
365
+
366
+ custom_embedding_info_html = f"{custom_embedding_info_css}{custom_embedding_info_html}"
367
+
368
+ return gr.update(value=custom_embedding_info_html, visible=True)
369
+ else:
370
+ return gr.update(value="", visible=False)
371
+
372
+
373
+ def add_to_embeddings(custom_embedding, enabled_embeddings):
374
+ link = custom_embedding.split("/")
375
+ if len(link) == 2:
376
+ if ModelCard.load(custom_embedding):
377
+ enabled_embeddings.append(custom_embedding)
378
+
379
+ return (
380
+ gr.update( # custom_embedding
381
+ value=""
382
+ ),
383
+ gr.update( # custom_embedding_info
384
+ value="",
385
+ visible=False
386
+ ),
387
+ gr.update( # enabled_embeddings
388
+ value=enabled_embeddings
389
+ )
390
+ )
391
+
392
+
393
+ def remove_from_embeddings(enabled_embeddings, index):
394
+ enabled_embeddings.pop(index)
395
+ return (
396
+ gr.update(
397
+ value=enabled_embeddings
398
+ )
399
+ )
400
+
401
+
402
+ # Generate Image
403
+ @spaces.GPU(duration=75)
404
+ def generate_image(
405
+ model, prompt, negative_prompt, fast_generation, enabled_loras, enabled_embeddings,
406
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
407
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image, # type: ignore
408
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength, # type: ignore
409
+ resize_mode,
410
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
411
+ image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
412
+ refiner, vae
413
+ ):
414
+ try:
415
+ base_args = {
416
+ "model": model,
417
+ "prompt": prompt,
418
+ # "negative_prompt": negative_prompt,
419
+ "fast_generation": fast_generation,
420
+ "loras": None,
421
+ # "embeddings": None,
422
+ "resize_mode": resize_mode,
423
+ "scheduler": scheduler,
424
+ "height": int(image_height),
425
+ "width": int(image_width),
426
+ "num_images_per_prompt": float(image_num_images_per_prompt),
427
+ "num_inference_steps": float(image_num_inference_steps),
428
+ # "clip_skip": None,
429
+ "guidance_scale": None,
430
+ "seed": int(image_seed),
431
+ "refiner": refiner,
432
+ "vae": vae,
433
+ "controlnet_config": None,
434
+ }
435
+ base_args = BaseReq(**base_args)
436
+
437
+ if len(enabled_loras) > 0:
438
+ base_args.loras = []
439
+ for enabled_lora, slider in zip(enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]):
440
+ if enabled_lora['repo_id']:
441
+ base_args.loras.append({
442
+ "repo_id": enabled_lora['repo_id'],
443
+ "weight": slider
444
+ })
445
+
446
+ # Load SDXL related args
447
+ if model in Config.IMAGES_MODELS:
448
+ if model['loader'] == 'sdxl':
449
+ base_args.negative_prompt = negative_prompt
450
+ base_args.clip_skip = image_clip_skip
451
+ if len(enabled_embeddings) > 0:
452
+ base_args.embeddings = enabled_embeddings
453
+
454
+ image = None
455
+ mask_image = None
456
+ strength = None
457
+
458
+ if img2img_image:
459
+ image = img2img_image
460
+ strength = float(img2img_strength)
461
+
462
+ base_args = BaseImg2ImgReq(
463
+ **base_args.__dict__,
464
+ image=image,
465
+ strength=strength
466
+ )
467
+ elif inpaint_image:
468
+ image = inpaint_image['background'] if not all(pixel == (0, 0, 0) for pixel in list(inpaint_image['background'].getdata())) else None
469
+ mask_image = inpaint_image['layers'][0] if image else None
470
+ strength = float(inpaint_strength)
471
+
472
+ if image and mask_image:
473
+ base_args = BaseInpaintReq(
474
+ **base_args.__dict__,
475
+ image=image,
476
+ mask_image=mask_image,
477
+ strength=strength
478
+ )
479
+ elif any([canny_image, pose_image, depth_image]):
480
+ base_args.controlnet_config = ControlNetReq(
481
+ controlnets=[],
482
+ control_images=[],
483
+ controlnet_conditioning_scale=[]
484
+ )
485
+
486
+ if canny_image:
487
+ base_args.controlnet_config.controlnets.append("canny")
488
+ base_args.controlnet_config.control_images.append(canny_image)
489
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(canny_strength))
490
+ if pose_image:
491
+ base_args.controlnet_config.controlnets.append("pose")
492
+ base_args.controlnet_config.control_images.append(pose_image)
493
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(pose_strength))
494
+ if depth_image:
495
+ base_args.controlnet_config.controlnets.append("depth")
496
+ base_args.controlnet_config.control_images.append(depth_image)
497
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
498
+ if model in Config.IMAGES_MODELS and model['loader'] == 'sdxl' and scribble_image:
499
+ base_args.controlnet_config.controlnets.append("scribble")
500
+ base_args.controlnet_config.control_images.append(scribble_image)
501
+ base_args.controlnet_config.controlnet_conditioning_scale.append(float(scribble_strength))
502
+ else:
503
+ base_args = BaseReq(**base_args.__dict__)
504
+
505
+ return gr.update(
506
+ value=gen_img(base_args),
507
+ interactive=True
508
+ )
509
+ except Exception as e:
510
+ raise gr.Error(f"Error: {e}") from e
tabs/images/handlers.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import random
3
+
4
+ import gradio as gr
5
+ import torch
6
+ from controlnet_aux import Processor
7
+ from safetensors.torch import load_file
8
+ from diffusers import (
9
+ AutoPipelineForText2Image,
10
+ AutoPipelineForImage2Image,
11
+ AutoPipelineForInpainting,
12
+ FluxPipeline,
13
+ FluxImg2ImgPipeline,
14
+ FluxInpaintPipeline,
15
+ FluxControlNetPipeline,
16
+ StableDiffusionXLPipeline,
17
+ StableDiffusionXLImg2ImgPipeline,
18
+ StableDiffusionXLInpaintPipeline,
19
+ StableDiffusionXLControlNetPipeline,
20
+ StableDiffusionXLControlNetImg2ImgPipeline,
21
+ StableDiffusionXLControlNetInpaintPipeline,
22
+ )
23
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1, get_weighted_text_embeddings_sdxl
24
+ from huggingface_hub import hf_hub_download
25
+ from diffusers.schedulers import *
26
+
27
+ from .models import *
28
+ from .load_models import device, models, flux_vae, sdxl_vae, refiner, controlnets
29
+
30
+ sd_pipes = (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
31
+ StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline)
32
+ flux_pipes = (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)
33
+
34
+
35
+ def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
36
+ for model in models:
37
+ if model['repo_id'] == request.model:
38
+ pipe_args = {
39
+ "pipeline": model['pipeline'],
40
+ }
41
+
42
+ # Set ControlNet config
43
+ if request.controlnet_config:
44
+ pipe_args["controlnet"] = []
45
+ if model['loader'] == 'sdxl' or model['loader'] == 'flux':
46
+ for controlnet in controlnets:
47
+ if request.controlnet_config.controlnet in controlnet['layers']:
48
+ pipe_args["controlnet"].append(controlnet['controlnet'])
49
+ elif model['loader'] == 'flux-multi':
50
+ controlnet = next((controlnet for controlnet in controlnets if controlnet['loader'] == 'flux-multi'), None)
51
+ if controlnet is not None:
52
+ # control_mode = list of index of layers
53
+ pipe_args['control_mode'] = [controlnet['layers'].index(layer) for layer in request.controlnet_config.controlnet]
54
+ pipe_args['controlnet'].append(controlnet['controlnet'])
55
+
56
+ # Choose Pipeline Mode
57
+ if not request.custom_addons:
58
+ if isinstance(request, BaseInpaintReq):
59
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
60
+ elif isinstance(request, BaseImg2ImgReq):
61
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
62
+ elif isinstance(request, BaseReq):
63
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
64
+ elif request.custom_addons:
65
+ ...
66
+
67
+ # Enable or Disable Vae
68
+ if request.vae:
69
+ pipe_args["pipeline"].vae = sdxl_vae if model['loader'] == 'sdxl' else flux_vae
70
+ elif not request.vae:
71
+ pipe_args["pipeline"].vae = None
72
+
73
+ # Set Scheduler
74
+ pipe_args["pipeline"].scheduler = get_scheduler(pipe_args["pipeline"], request.scheduler)
75
+
76
+ # Set Loras
77
+ if request.loras:
78
+ for i, lora in enumerate(request.loras):
79
+ pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
80
+ adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
81
+ adapter_weights = [lora['weight'] for lora in request.loras]
82
+
83
+ if request.fast_generation:
84
+ hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors") if model['loader'] == 'flux' \
85
+ else hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-8steps-lora.safetensors")
86
+ hyper_weight = 0.125 if model['loader'] == 'flux' else 1.0
87
+ pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
88
+ pipe_args["pipeline"].set_adapters(["hyper_lora"], [hyper_weight])
89
+
90
+ pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)
91
+
92
+ # Set Embeddings
93
+ if request.embeddings and model['loader'] == 'sdxl':
94
+ for embedding in request.embeddings:
95
+ state_dict = load_file(hf_hub_download(embedding['repo_id']))
96
+ pipe_args["pipeline"].load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipe_args["pipeline"].text_encoder_2, tokenizer=pipe_args["pipeline"].tokenizer_2)
97
+ pipe_args["pipeline"].load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipe_args["pipeline"].text_encoder, tokenizer=pipe_args["pipeline"].tokenizer)
98
+
99
+ return pipe_args
100
+
101
+
102
+ def load_scheduler(pipeline, scheduler):
103
+ schedulers = {
104
+ "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
105
+ "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
106
+ "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
107
+ "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
108
+ "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
109
+ "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
110
+ "dpm2": (KDPM2DiscreteScheduler, {}),
111
+ "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
112
+ "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
113
+ "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
114
+ "euler": (EulerDiscreteScheduler, {}),
115
+ "euler_a": (EulerAncestralDiscreteScheduler, {}),
116
+ "heun": (HeunDiscreteScheduler, {}),
117
+ "lms": (LMSDiscreteScheduler, {}),
118
+ "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
119
+ "deis": (DEISMultistepScheduler, {}),
120
+ "unipc": (UniPCMultistepScheduler, {}),
121
+ "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
122
+ }
123
+ scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
124
+
125
+ if scheduler_class is not None:
126
+ scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
127
+ else:
128
+ raise ValueError(f"Unknown scheduler: {scheduler}")
129
+
130
+ return scheduler
131
+
132
+
133
+ def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
134
+ for image in images:
135
+ if resize_mode == "resize_only":
136
+ image = image.resize((width, height))
137
+ elif resize_mode == "crop_and_resize":
138
+ image = image.crop((0, 0, width, height))
139
+ elif resize_mode == "resize_and_fill":
140
+ image = image.resize((width, height), Image.Resampling.LANCZOS)
141
+
142
+ return images
143
+
144
+
145
+ def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
146
+ response_images = []
147
+ control_images = resize_images(control_images, height, width, resize_mode)
148
+ for controlnet, image in zip(controlnets, control_images):
149
+ if controlnet == "canny":
150
+ processor = Processor('canny')
151
+ elif controlnet == "depth":
152
+ processor = Processor('depth_midas')
153
+ elif controlnet == "pose":
154
+ processor = Processor('openpose_full')
155
+ elif controlnet == "scribble":
156
+ processor = Processor('scribble')
157
+ else:
158
+ raise ValueError(f"Invalid Controlnet: {controlnet}")
159
+
160
+ response_images.append(processor(image, to_pil=True))
161
+
162
+ return response_images
163
+
164
+
165
+ def get_control_mode(controlnet_config: ControlNetReq):
166
+ control_mode = []
167
+ for controlnet in controlnets:
168
+ if controlnet['loader'] == 'flux-multi':
169
+ layers = controlnet['layers']
170
+
171
+ for c in controlnet_config.controlnets:
172
+ if c in layers:
173
+ control_mode.append(layers.index(c))
174
+
175
+ return control_mode
176
+
177
+
178
+ # def check_image_safety(images: List[Image.Image]):
179
+ # safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
180
+ # has_nsfw_concepts = safety_checker(
181
+ # images=[images],
182
+ # clip_input=safety_checker_input.pixel_values.to("cuda"),
183
+ # )
184
+
185
+ # return has_nsfw_concepts[1]
186
+
187
+
188
+ def get_prompt_attention(pipeline, prompt, negative_prompt):
189
+ if isinstance(pipeline, flux_pipes):
190
+ prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
191
+ return prompt_embeds, None, pooled_prompt_embeds, None
192
+ elif isinstance(pipeline, sd_pipes):
193
+ prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
194
+ return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
195
+
196
+
197
+ def cleanup(pipeline, loras = None, embeddings = None):
198
+ if loras:
199
+ pipeline.disable_lora()
200
+ pipeline.unload_lora_weights()
201
+ if embeddings:
202
+ pipeline.unload_textual_inversion()
203
+ gc.collect()
204
+ torch.cuda.empty_cache()
205
+
206
+
207
+ # Gen Function
208
+ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
209
+ pipeline_args = get_pipe(request)
210
+ pipeline = pipeline_args["pipeline"]
211
+ try:
212
+ positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
213
+
214
+ # Common Args
215
+ args = {
216
+ 'prompt_embeds': positive_prompt_embeds,
217
+ 'pooled_prompt_embeds': positive_prompt_pooled,
218
+ 'height': request.height,
219
+ 'width': request.width,
220
+ 'num_images_per_prompt': request.num_images_per_prompt,
221
+ 'num_inference_steps': request.num_inference_steps,
222
+ 'guidance_scale': request.guidance_scale,
223
+ 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
224
+ }
225
+
226
+ if isinstance(pipeline, sd_pipes):
227
+ args['clip_skip'] = request.clip_skip
228
+ args['negative_prompt_embeds'] = negative_prompt_embeds
229
+ args['negative_pooled_prompt_embeds'] = negative_prompt_pooled
230
+
231
+ if request.controlnet_config:
232
+ args['control_images'] = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode)
233
+ args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
234
+
235
+ if request.controlnet_config and isinstance(pipeline, flux_pipes):
236
+ args['control_mode'] = get_control_mode(request.controlnet_config)
237
+
238
+ if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
239
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
240
+ args['strength'] = request.strength
241
+
242
+ if isinstance(request, BaseInpaintReq):
243
+ args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]
244
+
245
+ # Generate
246
+ images = pipeline(**args).images
247
+
248
+ # Refiner
249
+ if request.refiner:
250
+ images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images
251
+
252
+ return images
253
+ except Exception as e:
254
+ cleanup(pipeline, request.loras, request.embeddings)
255
+ raise gr.Error(f"Error: {e}")
256
+ finally:
257
+ cleanup(pipeline, request.loras, request.embeddings)
tabs/images/load_models.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import (
3
+ AutoPipelineForText2Image,
4
+ AutoencoderKL,
5
+ FluxControlNetModel,
6
+ FluxMultiControlNetModel,
7
+ )
8
+ from diffusers.schedulers import *
9
+
10
+ from config import Config
11
+
12
+ def init_sys():
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ models = Config.IMAGES_MODELS
16
+
17
+ for model in models:
18
+ try:
19
+ model['pipeline'] = AutoPipelineForText2Image.from_pretrained(
20
+ model['repo_id'],
21
+ vae=None,
22
+ torch_dtype=model['compute_type'],
23
+ safety_checker=None,
24
+ variant="fp16"
25
+ ).to(device)
26
+ except:
27
+ model['pipeline'] = AutoPipelineForText2Image.from_pretrained(
28
+ model['repo_id'],
29
+ vae=None,
30
+ torch_dtype=model['compute_type'],
31
+ safety_checker=None
32
+ ).to(device)
33
+ model['pipeline'].enable_model_cpu_offload()
34
+
35
+ # VAE n Refiner
36
+ flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
37
+ sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
38
+ refiner = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16).to(device)
39
+
40
+ # ControlNet
41
+ controlnets = Config.IMAGES_CONTROLNETS
42
+ for controlnet in controlnets:
43
+ if controlnet['loader'] == 'flux-multi':
44
+ controlnet['controlnet'] = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
45
+ controlnet['repo_id'],
46
+ torch_dtype=controlnet['compute_type']
47
+ ).to(device)])
48
+ elif controlnet['loader'] == 'sdxl':
49
+ controlnet['controlnet'] = FluxControlNetModel.from_pretrained(
50
+ controlnet['repo_id'],
51
+ torch_dtype=controlnet['compute_type']
52
+ ).to(device)
53
+ elif controlnet['loader'] == 'flux':
54
+ controlnet['controlnet'] = FluxControlNetModel.from_pretrained(
55
+ controlnet['repo_id'],
56
+ torch_dtype=controlnet['compute_type']
57
+ ).to(device)
58
+
59
+ return device, models, flux_vae, sdxl_vae, refiner, controlnets
60
+
61
+ device, models, flux_vae, sdxl_vae, refiner, controlnets = init_sys()
tabs/images/models.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Dict, Any
2
+
3
+ from pydantic import BaseModel, field_validator
4
+ from PIL import Image
5
+
6
+ from config import Config as appConfig
7
+
8
+
9
+ class ControlNetReq(BaseModel):
10
+ controlnets: List[str] # ["canny", "tile", "depth", "scribble"]
11
+ control_images: List[Image.Image]
12
+ controlnet_conditioning_scale: List[float]
13
+
14
+ class Config:
15
+ arbitrary_types_allowed=True
16
+
17
+
18
+ class BaseReq(BaseModel):
19
+ model: str = ""
20
+ prompt: str = ""
21
+ negative_prompt: Optional[str] = None
22
+ fast_generation: Optional[bool] = True
23
+ loras: Optional[list] = []
24
+ embeddings: Optional[list] = None
25
+ resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
26
+ scheduler: Optional[str] = "euler_fl"
27
+ height: int = 1024
28
+ width: int = 1024
29
+ num_images_per_prompt: int = 1
30
+ num_inference_steps: int = 8
31
+ clip_skip: Optional[int] = None
32
+ guidance_scale: float = 3.5
33
+ seed: Optional[int] = 0
34
+ refiner: bool = False
35
+ vae: bool = True
36
+ controlnet_config: Optional[ControlNetReq] = None
37
+ custom_addons: Optional[Dict[Any, Any]] = None
38
+
39
+ class Config:
40
+ arbitrary_types_allowed=True
41
+
42
+ @field_validator('model', 'negative_prompt', 'embeddings', 'clip_skip', 'controlnet_config')
43
+ def check_model(cls, values):
44
+ for m in appConfig.IMAGES_MODELS:
45
+ if m.get('repo_id') == values.get('model'):
46
+ loader = m.get('loader')
47
+
48
+ if loader == "flux" and values.get('negative_prompt'):
49
+ raise ValueError("Negative prompt is not supported for Flux models.")
50
+ if loader == "flux" and values.get('embeddings'):
51
+ raise ValueError("Embeddings are not supported for Flux models.")
52
+ if loader == "flux" and values.get('clip_skip'):
53
+ raise ValueError("Clip skip is not supported for Flux models.")
54
+ if loader == "flux" and values.get('controlnet_config'):
55
+ if "scribble" in values.get('controlnet_config').controlnets:
56
+ raise ValueError("Scribble is not supported for Flux models.")
57
+ return values
58
+
59
+
60
+ class BaseImg2ImgReq(BaseReq):
61
+ image: Image.Image
62
+ strength: float = 1.0
63
+
64
+ class Config:
65
+ arbitrary_types_allowed=True
66
+
67
+
68
+ class BaseInpaintReq(BaseImg2ImgReq):
69
+ mask_image: Image.Image
70
+
71
+ class Config:
72
+ arbitrary_types_allowed=True
tabs/images/ui.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gradio as gr
4
+
5
+ from config import Config
6
+ from .events import *
7
+
8
+
9
+ def image_tab():
10
+ with gr.Row():
11
+ with gr.Column():
12
+ with gr.Group():
13
+ model = gr.Dropdown(label='Model', choices=[model['repo_id'] for model in Config.IMAGES_MODELS], value=Config.IMAGES_MODELS[0]['repo_id'], interactive=True)
14
+ prompt = gr.Textbox(lines=5, label='Prompt', placeholder='Enter your prompt here...', value='A beautiful sunset over the mountains.')
15
+ negative_prompt = gr.Textbox(lines=2, label='Negative Prompt', placeholder='Enter your negative prompt here...', visible=False)
16
+ fast_generation = gr.Checkbox(label='Fast Generation (Hyper-SD πŸ§ͺ)', value=False)
17
+
18
+
19
+ with gr.Accordion('Loras', open=True):
20
+ for m in Config.IMAGES_MODELS:
21
+ if m['repo_id'] == model.value:
22
+ lora_gallery_values = []
23
+ if m['loader'] == 'flux':
24
+ lora_gallery_values = [(lora['image'], lora['title']) for lora in Config.IMAGES_LORAS_FLUX]
25
+ elif m['loader'] == 'sdxl':
26
+ lora_gallery_values = [(lora['image'], lora['title']) for lora in Config.IMAGES_LORAS_SDXL]
27
+
28
+ lora_gallery = gr.Gallery(
29
+ label='Loras',
30
+ value=lora_gallery_values,
31
+ allow_preview=False,
32
+ interactive=True,
33
+ rows=2,
34
+ columns=3,
35
+ )
36
+
37
+ with gr.Group():
38
+ with gr.Column():
39
+ with gr.Row():
40
+ custom_lora = gr.Textbox(label='Custom Lora', info='Enter a Huggingface repo path')
41
+ selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
42
+
43
+ custom_lora_info = gr.HTML(visible=False)
44
+ add_lora = gr.Button(value="Add LoRA")
45
+
46
+ enabled_loras = gr.State(value=[])
47
+ with gr.Group():
48
+ with gr.Row():
49
+ for i in range(6): # only support max 6 loras due to inference time
50
+ with gr.Column():
51
+ with gr.Column(scale=2):
52
+ globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
53
+ with gr.Column():
54
+ globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
55
+
56
+
57
+ with gr.Accordion("Embeddings", open=False) as embeddings_accordion:
58
+ with gr.Group():
59
+ with gr.Row():
60
+ with gr.Group():
61
+ custom_embedding = gr.Textbox(label="Custom Embedding", info="Enter a Huggingface repo path")
62
+ add_embedding = gr.Button(value="Add Embedding")
63
+ custom_embedding_info = gr.HTML(visible=False)
64
+ with gr.Row():
65
+ enabled_embeddings = gr.State(value=[])
66
+ enabled_embeddings_list = gr.Checkboxgroup(label="Enabled Embeddings", choices=[], visible=False)
67
+
68
+
69
+ with gr.Accordion('Image Options', open=False):
70
+
71
+ with gr.Tabs():
72
+ image_options = [
73
+ ('img2img', 'Image to Image', 'image', True),
74
+ ('inpaint', 'Inpainting', 'imageeditor', True),
75
+ ('canny', 'Edge Detection', 'imageeditor', True),
76
+ ('pose', 'Pose Detection', 'imageeditor', True),
77
+ ('depth', 'Depth Estimation', 'imageeditor', True),
78
+ ('scribble', 'Scribble', 'imageeditor', False),
79
+ ]
80
+ for image_option, label, type, visible in image_options:
81
+ with gr.Tab(label=image_option) as globals()[f"{image_option}_tab"]:
82
+ if type == 'image':
83
+ globals()[f"{image_option}_image"] = gr.Image(label=label, visible=visible, interactive=True, type='pil')
84
+ elif type == 'imageeditor':
85
+ globals()[f"{image_option}_image"] = gr.ImageEditor(label=label, visible=visible, interactive=True,
86
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
87
+ type='pil', image_mode='RGB', layers=False)
88
+
89
+ globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
90
+
91
+ resize_mode = gr.Radio(
92
+ label="Resize Mode",
93
+ choices=["crop and resize", "resize only", "resize and fill"],
94
+ value="resize and fill",
95
+ interactive=True
96
+ )
97
+
98
+
99
+ with gr.Column():
100
+ with gr.Group():
101
+ output_images = gr.Gallery(label='Output Image', type='pil', interactive=False, value=[], allow_preview=True)
102
+ generate = gr.Button(value="Generate", variant="primary")
103
+
104
+
105
+ with gr.Accordion('Advance Settings', open=True):
106
+ scheduler = gr.Dropdown(
107
+ label='Scheduler',
108
+ choices = [
109
+ "dpmpp_2m", "dpmpp_2m_k", "dpmpp_2m_sde", "dpmpp_2m_sde_k",
110
+ "dpmpp_sde", "dpmpp_sde_k", "dpm2", "dpm2_k", "dpm2_a",
111
+ "dpm2_a_k", "euler", "euler_a", "heun", "lms", "lms_k",
112
+ "deis", "unipc", "fm_euler"
113
+ ],
114
+ value="fm_euler",
115
+ interactive=True
116
+ )
117
+
118
+ with gr.Row():
119
+ for column in range(2):
120
+ with gr.Column():
121
+ options = [
122
+ ("Height", "image_height", 64, 2048, 64, 1024, True),
123
+ ("Width", "image_width", 64, 2048, 64, 1024, True),
124
+ ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
125
+ ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
126
+ ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, True),
127
+ ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 7.0, True),
128
+ ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
129
+ ]
130
+ for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
131
+ globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
132
+
133
+ with gr.Row():
134
+ refiner = gr.Checkbox(label="Refiner", value=False)
135
+ vae = gr.Checkbox(label="VAE", value=False)
136
+
137
+ # Events
138
+ # Base Options
139
+ model.change(update_model_options, [model], [negative_prompt, lora_gallery, embeddings_accordion, scribble_tab, scheduler, image_clip_skip, image_guidance_scale]) # type: ignore
140
+ fast_generation.change(update_fast_generation, [model, fast_generation], [image_num_inference_steps, image_guidance_scale]) # type: ignore
141
+
142
+ # Loras
143
+ lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
144
+ custom_lora.change(update_selected_lora, custom_lora, [selected_lora, custom_lora_info])
145
+ add_lora.click(add_to_enabled_loras, [selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
146
+ enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
147
+
148
+ for i in range(6):
149
+ globals()[f"lora_remove_{i}"].click(
150
+ lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
151
+ [enabled_loras],
152
+ [enabled_loras]
153
+ )
154
+
155
+ # Embeddings
156
+ custom_embedding.change(update_custom_embedding, custom_embedding, [custom_embedding_info])
157
+ add_embedding.click(add_to_embeddings, [custom_embedding, enabled_embeddings], [custom_embedding, custom_embedding_info, enabled_embeddings])
158
+ for i in range(6):
159
+ globals()[f"embedding_remove_{i}"].click(
160
+ lambda enabled_embeddings, index=i: remove_from_embeddings(enabled_embeddings, index),
161
+ [enabled_embeddings],
162
+ [enabled_embeddings]
163
+ )
164
+
165
+ # Generate Image
166
+ generate.click(
167
+ generate_image, # type: ignore
168
+ [
169
+ model, prompt, negative_prompt, fast_generation, enabled_loras, enabled_embeddings,
170
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
171
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image, # type: ignore
172
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength, # type: ignore
173
+ resize_mode,
174
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
175
+ image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed, # type: ignore
176
+ refiner, vae
177
+ ],
178
+ [output_images]
179
+ )