primerz committed on
Commit f136417 · verified · 1 Parent(s): 0db314e

Update app.py

Files changed (1)
  1. app.py +549 -109
app.py CHANGED
@@ -1,166 +1,606 @@
- import os
- import re
  import time
- import json
  import copy
  import random
- import requests
- import torch
- import cv2
- import numpy as np
- import gradio as gr
- import spaces
- from PIL import Image
  from urllib.parse import quote

- # Disable Torch JIT compilation for compatibility
- torch.jit.script = lambda f: f
-
- # Model & Utilities
- import timm
  import diffusers
  from diffusers.utils import load_image
  from diffusers.models import ControlNetModel
  from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
- from safetensors.torch import load_file
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
  from insightface.app import FaceAnalysis
  from controlnet_aux import ZoeDetector
  from compel import Compel, ReturnedEmbeddingsType
  from gradio_imageslider import ImageSlider

- # Custom imports
- try:
-     from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
-     from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
- except ImportError as e:
-     print(f"Import Error: {e}. Check if modules exist or paths are correct.")
-     exit()

- # Device setup
- device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load LoRA configuration
  with open("sdxl_loras.json", "r") as file:
-     sdxl_loras_raw = json.load(file)

  with open("defaults_data.json", "r") as file:
      lora_defaults = json.load(file)

- # Download required models
- CHECKPOINT_DIR = "/data/checkpoints"
- hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=CHECKPOINT_DIR)
- hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINT_DIR)
- hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=CHECKPOINT_DIR)
- hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir=CHECKPOINT_DIR)

- # Download Antelopev2 Face Recognition model
- antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
- print("Antelopev2 Download Path:", antelope_download)

- # Initialize FaceAnalysis
- app = FaceAnalysis(name="antelopev2", root="/data", providers=["CPUExecutionProvider"])
  app.prepare(ctx_id=0, det_size=(640, 640))

- # Load identity & depth models
- face_adapter = os.path.join(CHECKPOINT_DIR, "ip-adapter.bin")
- controlnet_path = os.path.join(CHECKPOINT_DIR, "ControlNetModel")

  identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
- zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
-
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

- # Load main pipeline
- pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
-     "frankjoshua/albedobaseXL_v21",
-     vae=vae,
-     controlnet=[identitynet, zoedepthnet],
-     torch_dtype=torch.float16
- )
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
  pipe.load_ip_adapter_instantid(face_adapter)
  pipe.set_ip_adapter_scale(0.8)

- # Initialize Compel for text conditioning
- compel = Compel(
-     tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
-     text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
-     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
-     requires_pooled=[False, True]
- )
-
- # Load ZoeDetector for depth estimation
  zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
  zoe.to(device)
  pipe.to(device)

- # LoRA Management
  last_lora = ""
  last_fused = False

- # --- Utility Functions ---
- def update_selection(selected_state, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative):
-     index = selected_state.index
-     lora_repo = sdxl_loras[index]["repo"]
-     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"

      for lora_list in lora_defaults:
-         if lora_list["model"] == lora_repo:
              face_strength = lora_list.get("face_strength", 0.85)
              image_strength = lora_list.get("image_strength", 0.15)
              weight = lora_list.get("weight", 0.9)
              depth_control_scale = lora_list.get("depth_control_scale", 0.8)
              negative = lora_list.get("negative", "")
-
      return (
-         updated_text, gr.update(placeholder="Type a prompt"), face_strength,
-         image_strength, weight, depth_control_scale, negative, selected_state
      )

- def center_crop_image(img):
      square_size = min(img.size)
-     left = (img.width - square_size) // 2
-     top = (img.height - square_size) // 2
-     return img.crop((left, top, left + square_size, top + square_size))
-
- def process_face(image):
-     face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
-     face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0]) * (x['bbox'][3]-x['bbox'][1]))[-1]
-     face_emb = face_info['embedding']
-     face_kps = draw_kps(image, face_info['kps'])
-     return face_emb, face_kps
-
- def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, lora_scale):
-     global last_fused, last_lora
-     if last_lora != repo_name and last_fused:
-         pipe.unfuse_lora()
-         pipe.unload_lora_weights()
-     pipe.load_lora_weights(repo_name)
-     pipe.fuse_lora(lora_scale)
-     last_lora, last_fused = repo_name, True

      conditioning, pooled = compel(prompt)
-     negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
-
-     images = [face_kps, zoe(face_image).resize(face_kps.size)]
-     return pipe(
-         prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
-         negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled,
-         width=1024, height=1024, image_embeds=face_emb, image=face_image,
-         strength=1-image_strength, control_image=images, num_inference_steps=20,
-         guidance_scale=guidance_scale, controlnet_conditioning_scale=[face_strength, depth_control_scale]
      ).images[0]

- # --- UI Setup ---
- with gr.Blocks() as demo:
-     photo = gr.Image(label="Upload a picture", interactive=True, type="pil", height=300)
-     gallery = gr.Gallery(label="Pick a style", allow_preview=False, columns=4, height=550)
-     prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt...")
-     button = gr.Button("Run")
-     result = ImageSlider(interactive=False, label="Generated Image")

-     button.click(fn=generate_image, inputs=[prompt, gr.State(), gr.State()], outputs=result)

- demo.queue()
- demo.launch(share=True)

+ import gradio as gr
+ import torch
+ import spaces
+ torch.jit.script = lambda f: f
+ import timm
  import time
+
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
+ from safetensors.torch import load_file
+ from share_btn import community_icon_html, loading_icon_html, share_js
+ from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
+
+ import lora
  import copy
+ import json
+ import gc
  import random
  from urllib.parse import quote
+ import gdown
+ import os
+ import re
+ import requests

  import diffusers
  from diffusers.utils import load_image
  from diffusers.models import ControlNetModel
  from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
+ import cv2
+ import torch
+ import numpy as np
+ from PIL import Image
+
  from insightface.app import FaceAnalysis
+ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
  from controlnet_aux import ZoeDetector
+
  from compel import Compel, ReturnedEmbeddingsType
+
  from gradio_imageslider import ImageSlider


+ #from gradio_imageslider import ImageSlider

  with open("sdxl_loras.json", "r") as file:
+     data = json.load(file)
+     sdxl_loras_raw = [
+         {
+             "image": item["image"],
+             "title": item["title"],
+             "repo": item["repo"],
+             "trigger_word": item["trigger_word"],
+             "weights": item["weights"],
+             "is_compatible": item["is_compatible"],
+             "is_pivotal": item.get("is_pivotal", False),
+             "text_embedding_weights": item.get("text_embedding_weights", None),
+             "likes": item.get("likes", 0),
+             "downloads": item.get("downloads", 0),
+             "is_nc": item.get("is_nc", False),
+             "new": item.get("new", False),
+         }
+         for item in data
+     ]

  with open("defaults_data.json", "r") as file:
      lora_defaults = json.load(file)
+

+ device = "cuda"

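+ # Prefetch every gallery LoRA once at startup: cache the downloaded file path and its
+ # loaded state dict per repo id, so nothing has to be fetched while serving a request.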
+ state_dicts = {}
+
+ for item in sdxl_loras_raw:
+     saved_name = hf_hub_download(item["repo"], item["weights"])
+
+     if not saved_name.endswith('.safetensors'):
+         state_dict = torch.load(saved_name)
+     else:
+         state_dict = load_file(saved_name)
+
+     state_dicts[item["repo"]] = {
+         "saved_name": saved_name,
+         "state_dict": state_dict
+     }
+
+ sdxl_loras_raw = [item for item in sdxl_loras_raw if item.get("new") != True]
+
+ # download models
+ hf_hub_download(
+     repo_id="InstantX/InstantID",
+     filename="ControlNetModel/config.json",
+     local_dir="/data/checkpoints",
+ )
+ hf_hub_download(
+     repo_id="InstantX/InstantID",
+     filename="ControlNetModel/diffusion_pytorch_model.safetensors",
+     local_dir="/data/checkpoints",
+ )
+ hf_hub_download(
+     repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
+ )
+ hf_hub_download(
+     repo_id="latent-consistency/lcm-lora-sdxl",
+     filename="pytorch_lora_weights.safetensors",
+     local_dir="/data/checkpoints",
+ )
+ # download antelopev2
+ #if not os.path.exists("/data/antelopev2.zip"):
+ #    gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
+ #    os.system("unzip /data/antelopev2.zip -d /data/models/")

+ antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
+ print(antelope_download)
+ app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
  app.prepare(ctx_id=0, det_size=(640, 640))

+ # prepare models under ./checkpoints
+ face_adapter = f'/data/checkpoints/ip-adapter.bin'
+ controlnet_path = f'/data/checkpoints/ControlNetModel'

+ # load IdentityNet
+ st = time.time()
  identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+ zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
+ et = time.time()
+ elapsed_time = et - st
+ print('Loading ControlNet took: ', elapsed_time, 'seconds')
+ st = time.time()
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ et = time.time()
+ elapsed_time = et - st
+ print('Loading VAE took: ', elapsed_time, 'seconds')
+ st = time.time()

+ #pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("stablediffusionapi/albedobase-xl-v21",
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("frankjoshua/albedobaseXL_v21",
+                                                                  vae=vae,
+                                                                  controlnet=[identitynet, zoedepthnet],
+                                                                  torch_dtype=torch.float16)
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
  pipe.load_ip_adapter_instantid(face_adapter)
  pipe.set_ip_adapter_scale(0.8)
+ et = time.time()
+ elapsed_time = et - st
+ print('Loading pipeline took: ', elapsed_time, 'seconds')
+ st = time.time()
+ compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
+ et = time.time()
+ elapsed_time = et - st
+ print('Loading Compel took: ', elapsed_time, 'seconds')

+ st = time.time()
  zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+ et = time.time()
+ elapsed_time = et - st
+ print('Loading Zoe took: ', elapsed_time, 'seconds')
  zoe.to(device)
  pipe.to(device)

  last_lora = ""
  last_fused = False
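+ # JS snippet injected via demo.load() at the bottom of the file; lora_archive is where
+ # custom LoRA files fetched at runtime are stored.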
+ js = '''
+ var button = document.getElementById('button');
+ // Add a click event listener to the button
+ button.addEventListener('click', function() {
+     element.classList.add('selected');
+ });
+ '''
+ lora_archive = "/data"

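+ # Gallery selection handler: looks up per-LoRA defaults from defaults_data.json and
+ # returns the updated markdown, prompt placeholder and slider values.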
+ def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
+     lora_repo = sdxl_loras[selected_state.index]["repo"]
+     new_placeholder = "Type a prompt to use your selected LoRA"
+     weight_name = sdxl_loras[selected_state.index]["weights"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {'(non-commercial LoRA, `cc-by-nc`)' if sdxl_loras[selected_state.index]['is_nc'] else '' }"

      for lora_list in lora_defaults:
+         if lora_list["model"] == sdxl_loras[selected_state.index]["repo"]:
              face_strength = lora_list.get("face_strength", 0.85)
              image_strength = lora_list.get("image_strength", 0.15)
              weight = lora_list.get("weight", 0.9)
              depth_control_scale = lora_list.get("depth_control_scale", 0.8)
              negative = lora_list.get("negative", "")
+
+     if(is_new):
+         if(selected_state.index == 0):
+             selected_state.index = -9999
+         else:
+             selected_state.index *= -1
+
      return (
+         updated_text,
+         gr.update(placeholder=new_placeholder),
+         face_strength,
+         image_strength,
+         weight,
+         depth_control_scale,
+         negative,
+         selected_state
      )

+ def center_crop_image_as_square(img):
      square_size = min(img.size)
+
+     left = (img.width - square_size) / 2
+     top = (img.height - square_size) / 2
+     right = (img.width + square_size) / 2
+     bottom = (img.height + square_size) / 2
+
+     img_cropped = img.crop((left, top, right, bottom))
+     return img_cropped
+
+ def check_selected(selected_state, custom_lora):
+     if not selected_state and not custom_lora:
+         raise gr.Error("You must select a style")

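+ # Merges LoRAs that diffusers can't load natively via the imported `lora` helper module
+ # (not referenced elsewhere in this file).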
+ def merge_incompatible_lora(full_path_lora, lora_scale):
+     for weights_file in [full_path_lora]:
+         if ";" in weights_file:
+             weights_file, multiplier = weights_file.split(";")
+             multiplier = float(multiplier)
+         else:
+             multiplier = lora_scale

+         lora_model, weights_sd = lora.create_network_from_weights(
+             multiplier,
+             full_path_lora,
+             pipe.vae,
+             pipe.text_encoder,
+             pipe.unet,
+             for_inference=True,
+         )
+         lora_model.merge_to(
+             pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
+         )
+         del weights_sd
+         del lora_model

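+ # GPU worker: builds the Zoe depth control image, swaps in the requested LoRA
+ # (unfusing the previous one only when the selection changed, and loading pivotal-tuning
+ # embeddings when present), then runs the InstantID img2img pipeline.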
+ @spaces.GPU(duration=80)
+ def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
+     print(loaded_state_dict)
+     et = time.time()
+     elapsed_time = et - st
+     print('Getting into the decorated function took: ', elapsed_time, 'seconds')
+     global last_fused, last_lora
+     print("Last LoRA: ", last_lora)
+     print("Current LoRA: ", repo_name)
+     print("Last fused: ", last_fused)
+     #prepare face zoe
+     st = time.time()
+     with torch.no_grad():
+         image_zoe = zoe(face_image)
+     width, height = face_kps.size
+     images = [face_kps, image_zoe.resize((height, width))]
+     et = time.time()
+     elapsed_time = et - st
+     print('Zoe Depth calculations took: ', elapsed_time, 'seconds')
+     if last_lora != repo_name:
+         if(last_fused):
+             st = time.time()
+             pipe.unfuse_lora()
+             pipe.unload_lora_weights()
+             pipe.unload_textual_inversion()
+             et = time.time()
+             elapsed_time = et - st
+             print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
+         st = time.time()
+         pipe.load_lora_weights(loaded_state_dict)
+         pipe.fuse_lora(lora_scale)
+         et = time.time()
+         elapsed_time = et - st
+         print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
+         last_fused = True
+         is_pivotal = sdxl_loras[selected_state_index]["is_pivotal"]
+         if(is_pivotal):
+             #Add the textual inversion embeddings from pivotal tuning models
+             text_embedding_name = sdxl_loras[selected_state_index]["text_embedding_weights"]
+             embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
+             state_dict_embedding = load_file(embedding_path)
+             pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+             pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+
+     print("Processing prompt...")
+     st = time.time()
      conditioning, pooled = compel(prompt)
+     if(negative):
+         negative_conditioning, negative_pooled = compel(negative)
+     else:
+         negative_conditioning, negative_pooled = None, None
+     et = time.time()
+     elapsed_time = et - st
+     print('Prompt processing took: ', elapsed_time, 'seconds')
+     print("Processing image...")
+     st = time.time()
+     image = pipe(
+         prompt_embeds=conditioning,
+         pooled_prompt_embeds=pooled,
+         negative_prompt_embeds=negative_conditioning,
+         negative_pooled_prompt_embeds=negative_pooled,
+         width=1024,
+         height=1024,
+         image_embeds=face_emb,
+         image=face_image,
+         strength=1-image_strength,
+         control_image=images,
+         num_inference_steps=20,
+         guidance_scale=guidance_scale,
+         controlnet_conditioning_scale=[face_strength, depth_control_scale],
      ).images[0]
+     et = time.time()
+     elapsed_time = et - st
+     print('Image processing took: ', elapsed_time, 'seconds')
+     last_lora = repo_name
+     return image

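+ # Gradio entry point: center-crops the photo, extracts the InsightFace embedding and
+ # keypoints, resolves the prompt and LoRA (gallery pick or custom), then calls generate_image().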
+ def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora, progress=gr.Progress(track_tqdm=True)):
+     print("Custom LoRA: ", custom_lora)
+     custom_lora_path = custom_lora[0] if custom_lora else None
+     selected_state_index = selected_state.index if selected_state else -1
+     st = time.time()
+     face_image = center_crop_image_as_square(face_image)
+     try:
+         face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
+         face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
+         face_emb = face_info['embedding']
+         face_kps = draw_kps(face_image, face_info['kps'])
+     except:
+         raise gr.Error("No face found in your image. Only face images work here. Try again")
+     et = time.time()
+     elapsed_time = et - st
+     print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')
+
+     st = time.time()
+
+     if(custom_lora_path and custom_lora[1]):
+         prompt = f"{prompt} {custom_lora[1]}"
+     else:
+         for lora_list in lora_defaults:
+             if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
+                 prompt_full = lora_list.get("prompt", None)
+                 if(prompt_full):
+                     prompt = prompt_full.replace("<subject>", prompt)
+
+     print("Prompt:", prompt)
+     if(prompt == ""):
+         prompt = "a person"
+     print(f"Executing prompt: {prompt}")
+     #print("Selected State: ", selected_state_index)
+     #print(sdxl_loras[selected_state_index]["repo"])
+     if negative == "":
+         negative = None
+     print("Custom Loaded LoRA: ", custom_lora_path)
+     if not selected_state and not custom_lora_path:
+         raise gr.Error("You must select a style")
+     elif custom_lora_path:
+         repo_name = custom_lora_path
+         full_path_lora = custom_lora_path
+     else:
+         repo_name = sdxl_loras[selected_state_index]["repo"]
+         weight_name = sdxl_loras[selected_state_index]["weights"]
+         full_path_lora = state_dicts[repo_name]["saved_name"]
+     print("Full path LoRA ", full_path_lora)
+     #loaded_state_dict = copy.deepcopy(state_dicts[repo_name]["state_dict"])
+     cross_attention_kwargs = None
+     et = time.time()
+     elapsed_time = et - st
+     print('Small content processing took: ', elapsed_time, 'seconds')
+
+     st = time.time()
+     image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, full_path_lora, lora_scale, sdxl_loras, selected_state_index, st)
+     return (face_image, image), gr.update(visible=True)
+
+ run_lora.zerogpu = True
+
+ def shuffle_gallery(sdxl_loras):
+     random.shuffle(sdxl_loras)
+     return [(item["image"], item["title"]) for item in sdxl_loras], sdxl_loras
+
+ def classify_gallery(sdxl_loras):
+     sorted_gallery = sorted(sdxl_loras, key=lambda x: x.get("likes", 0), reverse=True)
+     return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
+
+ def swap_gallery(order, sdxl_loras):
+     if(order == "random"):
+         return shuffle_gallery(sdxl_loras)
+     else:
+         return classify_gallery(sdxl_loras)
+
+ def deselect():
+     return gr.Gallery(selected_index=None)

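+ # Resolve a Hugging Face repo id to a local .safetensors file, trigger word and preview image URL.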
+ def get_huggingface_safetensors(link):
+     split_link = link.split("/")
+     if(len(split_link) == 2):
+         model_card = ModelCard.load(link)
+         image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+         trigger_word = model_card.data.get("instance_prompt", "")
+         image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+         fs = HfFileSystem()
+         try:
+             list_of_files = fs.ls(link, detail=False)
+             for file in list_of_files:
+                 if(file.endswith(".safetensors")):
+                     safetensors_name = file.replace("/", "_")
+                     if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
+                         fs.get_file(file, lpath=f"{lora_archive}/{safetensors_name}")
+                 if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
+                     image_elements = file.split("/")
+                     image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
+         except:
+             gr.Warning("You didn't include a valid link or a Hugging Face repository with a *.safetensors LoRA")
+             raise Exception("You didn't include a valid link or a Hugging Face repository with a *.safetensors LoRA")
+         return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url

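+ # Same for CivitAI links: validate that the model is an SDXL LoRA, download its .safetensors
+ # file and pick a preview image.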
+ def get_civitai_safetensors(link):
+     link_split = link.split("civitai.com/")
+     pattern = re.compile(r'models\/(\d+)')
+     regex_match = pattern.search(link_split[1])
+     if(regex_match):
+         civitai_model_id = regex_match.group(1)
+     else:
+         gr.Warning("No CivitAI model id found in your URL")
+         raise Exception("No CivitAI model id found in your URL")
+     model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}?token={os.getenv('CIVITAI_TOKEN')}"
+     x = requests.get(model_request_url)
+     if(x.status_code != 200):
+         raise Exception("Invalid CivitAI URL")
+     model_data = x.json()
+     #if(model_data["nsfw"] == True or model_data["nsfwLevel"] > 20):
+     #    gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
+     #    raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
+     if(model_data["type"] != "LORA"):
+         gr.Warning("The model isn't tagged at CivitAI as a LoRA")
+         raise Exception("The model isn't tagged at CivitAI as a LoRA")
+     model_link_download = None
+     image_url = None
+     trigger_word = ""
+     for model in model_data["modelVersions"]:
+         if(model["baseModel"] == "SDXL 1.0"):
+             model_link_download = f"{model['downloadUrl']}/?token={os.getenv('CIVITAI_TOKEN')}"
+             safetensors_name = model["files"][0]["name"]
+             if(not os.path.exists(f"{lora_archive}/{safetensors_name}")):
+                 safetensors_file_request = requests.get(model_link_download)
+                 if(safetensors_file_request.status_code != 200):
+                     raise Exception("Invalid CivitAI download link")
+                 with open(f"{lora_archive}/{safetensors_name}", 'wb') as file:
+                     file.write(safetensors_file_request.content)
+             trigger_word = model.get("trainedWords", [""])[0]
+             for image in model["images"]:
+                 if(image["nsfwLevel"] == 1):
+                     image_url = image["url"]
+                     break
+             break
+     if(not model_link_download):
+         gr.Warning("We couldn't find an SDXL LoRA on the model you've sent")
+         raise Exception("We couldn't find an SDXL LoRA on the model you've sent")
+     return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
+
+ def check_custom_model(link):
+     if(link.startswith("https://")):
+         if(link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co")):
+             link_split = link.split("huggingface.co/")
+             return get_huggingface_safetensors(link_split[1])
+         elif(link.startswith("https://civitai.com") or link.startswith("https://www.civitai.com")):
+             return get_civitai_safetensors(link)
+     else:
+         return get_huggingface_safetensors(link)
+
+ def show_loading_widget():
+     return gr.update(visible=True)
+
+ def load_custom_lora(link):
+     if(link):
+         try:
+             title, path, trigger_word, image = check_custom_model(link)
+             card = f'''
+             <div class="custom_lora_card">
+                 <span>Loaded custom LoRA:</span>
+                 <div class="card_internal">
+                     <img src="{image}" />
+                     <div>
+                         <h3>{title}</h3>
+                         <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
+                     </div>
+                 </div>
+             </div>
+             '''
+             return gr.update(visible=True), card, gr.update(visible=True), [path, trigger_word], gr.Gallery(selected_index=None), f"Custom: {path}"
+         except Exception as e:
+             gr.Warning("Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content")
+             return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
+     else:
+         return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
+
+ def remove_custom_lora():
+     return "", gr.update(visible=False), gr.update(visible=False), None

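+ # UI: style gallery and custom LoRA input on the left; prompt, result slider and advanced
+ # sliders on the right, with events wired to the handlers above.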
+ with gr.Blocks(css="custom.css") as demo:
+     gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
+     title = gr.HTML(
+         """<h1><img src="https://i.imgur.com/DVoGw04.png">
+         <span>Face to All<br><small style="
+             font-size: 13px;
+             display: block;
+             font-weight: normal;
+             opacity: 0.75;
+         ">🧨 diffusers InstantID + ControlNet<br> inspired by fofr's <a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a></small></span></h1>""",
+         elem_id="title",
+     )
+     selected_state = gr.State()
+     custom_loaded_lora = gr.State()
+     with gr.Row(elem_id="main_app"):
+         with gr.Column(scale=4, elem_id="box_column"):
+             with gr.Group(elem_id="gallery_box"):
+                 photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
+                 selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
+                 #order_gallery = gr.Radio(choices=["random", "likes"], value="random", label="Order by", elem_id="order_radio")
+                 #new_gallery = gr.Gallery(
+                 #    label="New LoRAs",
+                 #    elem_id="gallery_new",
+                 #    columns=3,
+                 #    value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
+                 gallery = gr.Gallery(
+                     #value=[(item["image"], item["title"]) for item in sdxl_loras],
+                     label="Pick a style from the gallery",
+                     allow_preview=False,
+                     columns=4,
+                     elem_id="gallery",
+                     show_share_button=False,
+                     height=550
+                 )
+                 custom_model = gr.Textbox(label="or enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
+                 custom_model_card = gr.HTML(visible=False)
+                 custom_model_button = gr.Button("Remove custom LoRA", visible=False)
+         with gr.Column(scale=5):
+             with gr.Row():
+                 prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
+                 button = gr.Button("Run", elem_id="run_button")
+             result = ImageSlider(
+                 interactive=False, label="Generated Image", elem_id="result-image", position=0.1
+             )
+             with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
+                 community_icon = gr.HTML(community_icon_html)
+                 loading_icon = gr.HTML(loading_icon_html)
+                 share_button = gr.Button("Share to community", elem_id="share-btn")
+             with gr.Accordion("Advanced options", open=False):
+                 negative = gr.Textbox(label="Negative Prompt")
+                 weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
+                 face_strength = gr.Slider(0, 2, value=0.85, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
+                 image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
+                 guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
+                 depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strength")
+             prompt_title = gr.Markdown(
+                 value="### Click on a LoRA in the gallery to select it",
+                 visible=True,
+                 elem_id="selected_lora",
+             )
+     #order_gallery.change(
+     #    fn=swap_gallery,
+     #    inputs=[order_gallery, gr_sdxl_loras],
+     #    outputs=[gallery, gr_sdxl_loras],
+     #    queue=False
+     #)
+     custom_model.input(
+         fn=load_custom_lora,
+         inputs=[custom_model],
+         outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
+     )
+     custom_model_button.click(
+         fn=remove_custom_lora,
+         outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
+     )
+     gallery.select(
+         fn=update_selection,
+         inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
+         outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
+         show_progress=False
+     )
+     #new_gallery.select(
+     #    fn=update_selection,
+     #    inputs=[gr_sdxl_loras_new, gr.State(True)],
+     #    outputs=[prompt_title, prompt, prompt, selected_state, gallery],
+     #    queue=False,
+     #    show_progress=False
+     #)
+     prompt.submit(
+         fn=check_selected,
+         inputs=[selected_state, custom_loaded_lora],
+         show_progress=False
+     ).success(
+         fn=run_lora,
+         inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
+         outputs=[result, share_group],
+     )
+     button.click(
+         fn=check_selected,
+         inputs=[selected_state, custom_loaded_lora],
+         show_progress=False
+     ).success(
+         fn=run_lora,
+         inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
+         outputs=[result, share_group],
+     )
+     share_button.click(None, [], [], js=share_js)
+     demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], js=js)

+ demo.queue(default_concurrency_limit=None, api_open=True)
+ demo.launch(share=True)