Update inference_manager.py
Browse files- inference_manager.py +212 -18
inference_manager.py
CHANGED
@@ -9,16 +9,19 @@ from huggingface_hub import hf_hub_download, snapshot_download
|
|
9 |
from pathlib import Path
|
10 |
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
|
11 |
from diffusers.models.attention_processor import AttnProcessor2_0
|
12 |
-
import os
|
13 |
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
14 |
from cryptography.hazmat.primitives import serialization, hashes
|
15 |
from cryptography.hazmat.backends import default_backend
|
16 |
from cryptography.hazmat.primitives.asymmetric import utils
|
17 |
import base64
|
18 |
import json
|
|
|
19 |
import jwt
|
20 |
import glob
|
21 |
import traceback
|
|
|
|
|
|
|
22 |
|
23 |
#from onediffx import compile_pipe, save_pipe, load_pipe
|
24 |
|
@@ -66,37 +69,57 @@ class AuthHelper:
|
|
66 |
print("Invalid token:", e)
|
67 |
raise
|
68 |
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
71 |
if params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
|
72 |
return True
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
jwt_data = self.decode_jwt(token)
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
return True
|
86 |
-
|
|
|
87 |
|
88 |
class InferenceManager:
|
89 |
-
def __init__(self, config_path="config.json"):
|
90 |
cfg = {}
|
91 |
with open(config_path, "r", encoding="utf-8") as f:
|
92 |
cfg = json.load(f)
|
93 |
self.cfg = cfg
|
|
|
|
|
94 |
lora_options_path = cfg.get("loras", "")
|
95 |
self.model_version = cfg["model_version"]
|
96 |
self.lora_load_options = self.load_json(lora_options_path) # Load LoRA load options
|
97 |
self.lora_models = self.load_index_file("index.json") # Load index.json
|
98 |
self.preloaded_loras = [] # Array to store preloaded LoRAs with name and weights
|
|
|
99 |
self.base_model_pipeline = self.load_base_model() # Load the base model
|
|
|
100 |
self.preload_loras() # Preload LoRAs based on options
|
101 |
|
102 |
def load_json(self, filepath):
|
@@ -165,6 +188,7 @@ class InferenceManager:
|
|
165 |
#unet=unet,
|
166 |
torch_dtype=torch.bfloat16,
|
167 |
use_safetensors=True,
|
|
|
168 |
#variant="fp16",
|
169 |
custom_pipeline = "lpw_stable_diffusion_xl",
|
170 |
)
|
@@ -175,8 +199,19 @@ class InferenceManager:
|
|
175 |
|
176 |
load_time = round(time.time() - start, 2)
|
177 |
print(f"Base model loaded in {load_time}s")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
return pipe
|
179 |
|
|
|
180 |
def preload_loras(self):
|
181 |
"""Preload all LoRAs marked as 'preload=True' and store for later use."""
|
182 |
for lora_name, lora_info in self.lora_load_options.items():
|
@@ -279,9 +314,36 @@ class ModelManager:
|
|
279 |
|
280 |
:param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
|
281 |
"""
|
|
|
|
|
|
|
|
|
|
|
282 |
self.models = {}
|
|
|
283 |
self.model_directory = model_directory
|
284 |
self.load_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
def load_models(self):
|
287 |
"""
|
@@ -299,7 +361,7 @@ class ModelManager:
|
|
299 |
print(f"Initializing model: {model_name} from {file_path}")
|
300 |
try:
|
301 |
# Initialize InferenceManager for each model
|
302 |
-
self.models[model_name] = InferenceManager(config_path=file_path)
|
303 |
except Exception as e:
|
304 |
print(traceback.format_exc())
|
305 |
print(f"Failed to initialize model {model_name} from {file_path}: {e}")
|
@@ -352,9 +414,141 @@ class ModelManager:
|
|
352 |
model.release(model.base_model_pipeline)
|
353 |
except Exception as e:
|
354 |
print(f"Failed to release model {model_id}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
355 |
|
356 |
# Hugging Face file download function - returns only file path
|
357 |
-
def download_from_hf(filename, local_dir=None):
|
358 |
try:
|
359 |
file_path = hf_hub_download(
|
360 |
filename=filename,
|
|
|
9 |
from pathlib import Path
|
10 |
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
|
11 |
from diffusers.models.attention_processor import AttnProcessor2_0
|
|
|
12 |
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
13 |
from cryptography.hazmat.primitives import serialization, hashes
|
14 |
from cryptography.hazmat.backends import default_backend
|
15 |
from cryptography.hazmat.primitives.asymmetric import utils
|
16 |
import base64
|
17 |
import json
|
18 |
+
import ipown
|
19 |
import jwt
|
20 |
import glob
|
21 |
import traceback
|
22 |
+
from insightface.app import FaceAnalysis
|
23 |
+
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
24 |
+
import cv2
|
25 |
|
26 |
#from onediffx import compile_pipe, save_pipe, load_pipe
|
27 |
|
|
|
69 |
print("Invalid token:", e)
|
70 |
raise
|
71 |
|
72 |
+
import hashlib
|
73 |
+
|
74 |
+
def check_auth(self, request, token):
|
75 |
+
# Extract parameters from the request
|
76 |
+
params = dict(request.query_params)
|
77 |
if params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
|
78 |
return True
|
79 |
+
|
80 |
+
# Gather request-specific information
|
81 |
+
sip = request.client.host
|
82 |
+
shost = request.headers.get("Host", "")
|
83 |
+
sreferer = request.headers.get("Referer", "")
|
84 |
+
suseragent = request.headers.get("User-Agent", "")
|
85 |
+
|
86 |
+
print(sip, shost, sreferer, suseragent)
|
87 |
+
|
88 |
+
# Decode the JWT token
|
89 |
jwt_data = self.decode_jwt(token)
|
90 |
+
jwt_auth = jwt_data.get("auth", "")
|
91 |
+
|
92 |
+
if not jwt_auth:
|
93 |
+
raise Exception("Missing auth field in token")
|
94 |
+
|
95 |
+
# Create the MD5 hash of ip + host + referer + useragent
|
96 |
+
auth_string = f"{sip}{shost}{sreferer}{suseragent}"
|
97 |
+
calculated_md5 = hashlib.md5(auth_string.encode('utf-8')).hexdigest()
|
98 |
+
|
99 |
+
print(f"Calculated MD5: {calculated_md5}, JWT Auth: {jwt_auth}")
|
100 |
+
|
101 |
+
# Compare the calculated hash with the `auth` field from the JWT
|
102 |
+
if calculated_md5 == jwt_auth:
|
103 |
return True
|
104 |
+
|
105 |
+
raise Exception("Invalid authentication")
|
106 |
|
107 |
class InferenceManager:
|
108 |
+
def __init__(self, config_path="config.json", ext_model_pathes={}):
|
109 |
cfg = {}
|
110 |
with open(config_path, "r", encoding="utf-8") as f:
|
111 |
cfg = json.load(f)
|
112 |
self.cfg = cfg
|
113 |
+
self.ext_model_pathes = ext_model_pathes
|
114 |
+
|
115 |
lora_options_path = cfg.get("loras", "")
|
116 |
self.model_version = cfg["model_version"]
|
117 |
self.lora_load_options = self.load_json(lora_options_path) # Load LoRA load options
|
118 |
self.lora_models = self.load_index_file("index.json") # Load index.json
|
119 |
self.preloaded_loras = [] # Array to store preloaded LoRAs with name and weights
|
120 |
+
self.ip_adapter_faceid_pipeline = None
|
121 |
self.base_model_pipeline = self.load_base_model() # Load the base model
|
122 |
+
|
123 |
self.preload_loras() # Preload LoRAs based on options
|
124 |
|
125 |
def load_json(self, filepath):
|
|
|
188 |
#unet=unet,
|
189 |
torch_dtype=torch.bfloat16,
|
190 |
use_safetensors=True,
|
191 |
+
sampler=cfg.get("sampler"),
|
192 |
#variant="fp16",
|
193 |
custom_pipeline = "lpw_stable_diffusion_xl",
|
194 |
)
|
|
|
199 |
|
200 |
load_time = round(time.time() - start, 2)
|
201 |
print(f"Base model loaded in {load_time}s")
|
202 |
+
|
203 |
+
if cfg.get("load_ip_adapter_faceid", False):
|
204 |
+
if model_version in ("pony", "xl"):
|
205 |
+
ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
|
206 |
+
if ip_ckpt:
|
207 |
+
print(f"loading ip adapter model for {model_name}")
|
208 |
+
self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, 'cuda')
|
209 |
+
else:
|
210 |
+
print("ip-adapter-faceid-sdxl not found, skip")
|
211 |
+
|
212 |
return pipe
|
213 |
|
214 |
+
|
215 |
def preload_loras(self):
|
216 |
"""Preload all LoRAs marked as 'preload=True' and store for later use."""
|
217 |
for lora_name, lora_info in self.lora_load_options.items():
|
|
|
314 |
|
315 |
:param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
|
316 |
"""
|
317 |
+
print("downloading models...")
|
318 |
+
self.ext_model_pathes = {
|
319 |
+
"ip-adapter-faceid-sdxl": hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sdxl.bin", repo_type="model")
|
320 |
+
}
|
321 |
+
|
322 |
self.models = {}
|
323 |
+
self.ext_models = {}
|
324 |
self.model_directory = model_directory
|
325 |
self.load_models()
|
326 |
+
|
327 |
+
#not enabled at the moment
|
328 |
+
def load_instant_x(self):
|
329 |
+
#load all models
|
330 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
|
331 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
|
332 |
+
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
|
333 |
+
os.makedirs("./models",exist_ok=True)
|
334 |
+
download_from_hf("models/antelopev2/1k3d68.onnx",local_dir="./models")
|
335 |
+
download_from_hf("models/antelopev2/2d106det.onnx",local_dir="./models")
|
336 |
+
download_from_hf("models/antelopev2/genderage.onnx",local_dir="./models")
|
337 |
+
download_from_hf("models/antelopev2/glintr100.onnx",local_dir="./models")
|
338 |
+
download_from_hf("models/antelopev2/scrfd_10g_bnkps.onnx",local_dir="./models")
|
339 |
+
|
340 |
+
# prepare 'antelopev2' under ./models
|
341 |
+
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
342 |
+
app.prepare(ctx_id=0, det_size=(640, 640))
|
343 |
+
|
344 |
+
# prepare models under ./checkpoints
|
345 |
+
face_adapter = f'./checkpoints/ip-adapter.bin'
|
346 |
+
controlnet_path = f'./checkpoints/ControlNetModel'
|
347 |
|
348 |
def load_models(self):
|
349 |
"""
|
|
|
361 |
print(f"Initializing model: {model_name} from {file_path}")
|
362 |
try:
|
363 |
# Initialize InferenceManager for each model
|
364 |
+
self.models[model_name] = InferenceManager(config_path=file_path, ext_model_pathes=self.ext_model_pathes)
|
365 |
except Exception as e:
|
366 |
print(traceback.format_exc())
|
367 |
print(f"Failed to initialize model {model_name} from {file_path}: {e}")
|
|
|
414 |
model.release(model.base_model_pipeline)
|
415 |
except Exception as e:
|
416 |
print(f"Failed to release model {model_id}: {e}")
|
417 |
+
|
418 |
+
@spaces.GPU(duration=40)
|
419 |
+
def generate_with_faceid(self, model_id, request, inference_params, progress=gr.Progress(track_tqdm=True)):
|
420 |
+
auth_helper.check_auth(request, token)
|
421 |
+
model = self.models.get(model_id)
|
422 |
+
if not model:
|
423 |
+
raise Exception(f"invalid model_id {model_id}")
|
424 |
+
if not model.ip_adapter_faceid_pipeline:
|
425 |
+
raise Exception(f"model does not support ip adapter")
|
426 |
+
pipe = model.ip_adapter_faceid_pipeline
|
427 |
+
cfg = model.cfg
|
428 |
+
p = inference_params.get("prompt")
|
429 |
+
negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
|
430 |
+
steps = inference_params.get("steps", cfg.get("inference_steps", 30))
|
431 |
+
guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
|
432 |
+
width = inference_params.get("width", cfg.get("width", 512))
|
433 |
+
height = inference_params.get("height", cfg.get("height", 512))
|
434 |
+
images = inference_params.get("images", [])
|
435 |
+
likeness_strength = inference_params.get("likeness_strength", 0.4)
|
436 |
+
face_strength = inference_params.get("face_strength", 0.1)
|
437 |
+
sampler = inference_params.get("sampler", cfg.get("sampler", ""))
|
438 |
+
lora_list = inference_params.get("loras", [])
|
439 |
+
|
440 |
+
if not images:
|
441 |
+
raise Exception(f"face images not provided")
|
442 |
+
start = time.time()
|
443 |
+
pipe.to("cuda")
|
444 |
+
print("loading face analysis...")
|
445 |
+
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
446 |
+
app.prepare(ctx_id=0, det_size=(512, 512))
|
447 |
+
|
448 |
+
faceid_all_embeds = []
|
449 |
+
for image in images:
|
450 |
+
face = cv2.imread(image)
|
451 |
+
faces = app.get(face)
|
452 |
+
faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
|
453 |
+
faceid_all_embeds.append(faceid_embed)
|
454 |
+
|
455 |
+
average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
|
456 |
+
|
457 |
+
print("start inference...")
|
458 |
+
style_selection = ""
|
459 |
+
use_negative_prompt = True
|
460 |
+
randomize_seed = True
|
461 |
+
seed = seed or int(randomize_seed_fn(seed, randomize_seed))
|
462 |
+
p = remove_child_related_content(p)
|
463 |
+
prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
|
464 |
+
generator = torch.Generator(pipe.device).manual_seed(seed)
|
465 |
+
print(f"generate: p={p}, np={np}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
|
466 |
+
images = pipe(
|
467 |
+
prompt=prompt_str,
|
468 |
+
negative_prompt=negative_prompt,
|
469 |
+
faceid_embeds=average_embedding,
|
470 |
+
scale=likeness_strength,
|
471 |
+
width=width,
|
472 |
+
height=height,
|
473 |
+
guidance_scale=face_strength,
|
474 |
+
num_inference_steps=steps,
|
475 |
+
generator=generator,
|
476 |
+
num_images_per_prompt=1,
|
477 |
+
output_type="pil",
|
478 |
+
#callback_on_step_end=callback_dynamic_cfg,
|
479 |
+
#callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
|
480 |
+
).images
|
481 |
+
cost = round(time.time() - start, 2)
|
482 |
+
print(f"inference done in {cost}s")
|
483 |
+
images = [save_image(img) for img in images]
|
484 |
+
image_paths = [i[1] for i in images]
|
485 |
+
print(prompt_str, image_paths)
|
486 |
+
return [i[0] for i in images]
|
487 |
+
|
488 |
+
@spaces.GPU(duration=40)
|
489 |
+
def generate(self, model_id, request, inference_params, progress=gr.Progress(track_tqdm=True)):
|
490 |
+
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
491 |
+
cfg_disabling_at = cfg.get('cfg_disabling_rate', 0.75)
|
492 |
+
if step_index == int(pipe.num_timesteps * cfg_disabling_at):
|
493 |
+
callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
|
494 |
+
callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
|
495 |
+
callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
|
496 |
+
pipe._guidance_scale = 0.0
|
497 |
+
|
498 |
+
return callback_kwargs
|
499 |
+
auth_helper.check_auth(request, token)
|
500 |
+
model = self.models.get(model_id)
|
501 |
+
if not model:
|
502 |
+
raise Exception(f"invalid model_id {model_id}")
|
503 |
+
if not model.ip_adapter_faceid_pipeline:
|
504 |
+
raise Exception(f"model does not support ip adapter")
|
505 |
+
|
506 |
+
cfg = model.cfg
|
507 |
+
p = inference_params.get("prompt")
|
508 |
+
negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
|
509 |
+
inference_steps = inference_params.get("steps", cfg.get("inference_steps", 30))
|
510 |
+
guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
|
511 |
+
width = inference_params.get("width", cfg.get("width", 512))
|
512 |
+
height = inference_params.get("height", cfg.get("height", 512))
|
513 |
+
sampler = inference_params.get("sampler", cfg.get("sampler", ""))
|
514 |
+
lora_list = inference_params.get("loras", [])
|
515 |
+
|
516 |
+
pipe = model.build_pipeline_with_lora(lora_list, sampler, lora_list)
|
517 |
+
|
518 |
+
start = time.time()
|
519 |
+
pipe.to("cuda")
|
520 |
+
print("start inference...")
|
521 |
+
style_selection = ""
|
522 |
+
use_negative_prompt = True
|
523 |
+
randomize_seed = True
|
524 |
+
seed = seed or int(randomize_seed_fn(seed, randomize_seed))
|
525 |
+
guidance_scale = guidance_scale or cfg.get("guidance_scale", 7.5)
|
526 |
+
p = remove_child_related_content(p)
|
527 |
+
prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
|
528 |
+
generator = torch.Generator(pipe.device).manual_seed(seed)
|
529 |
+
print(f"generate: p={p}, np={np}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
|
530 |
+
images = pipe(
|
531 |
+
prompt=prompt_str,
|
532 |
+
negative_prompt=negative_prompt,
|
533 |
+
width=width,
|
534 |
+
height=height,
|
535 |
+
guidance_scale=guidance_scale,
|
536 |
+
num_inference_steps=inference_steps,
|
537 |
+
generator=generator,
|
538 |
+
num_images_per_prompt=1,
|
539 |
+
output_type="pil",
|
540 |
+
callback_on_step_end=callback_dynamic_cfg,
|
541 |
+
callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
|
542 |
+
).images
|
543 |
+
cost = round(time.time() - start, 2)
|
544 |
+
print(f"inference done in {cost}s")
|
545 |
+
images = [save_image(img) for img in images]
|
546 |
+
image_paths = [i[1] for i in images]
|
547 |
+
print(prompt_str, image_paths)
|
548 |
+
return [i[0] for i in images]
|
549 |
|
550 |
# Hugging Face file download function - returns only file path
|
551 |
+
def download_from_hf(filename, local_dir=None, repo_id=DATASET_ID, repo_type="dataset"):
|
552 |
try:
|
553 |
file_path = hf_hub_download(
|
554 |
filename=filename,
|