nsfwalex committed · verified
Commit 3f6dbd6 · 1 Parent(s): 3797680

Update inference_manager.py

Files changed (1):
  inference_manager.py (+212 −18)

inference_manager.py CHANGED
@@ -9,16 +9,19 @@ from huggingface_hub import hf_hub_download, snapshot_download
 from pathlib import Path
 from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
 from diffusers.models.attention_processor import AttnProcessor2_0
-import os
 from cryptography.hazmat.primitives.asymmetric import rsa, padding
 from cryptography.hazmat.primitives import serialization, hashes
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives.asymmetric import utils
 import base64
 import json
+import ipown
 import jwt
 import glob
 import traceback
+from insightface.app import FaceAnalysis
+from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
+import cv2
 
 #from onediffx import compile_pipe, save_pipe, load_pipe
 
@@ -66,37 +69,57 @@ class AuthHelper:
             print("Invalid token:", e)
             raise
 
-    def check_auth(self, session, token):
-        params = session.get("params") or {}
+    def check_auth(self, request, token):
+        import hashlib  # used by the fingerprint check below
+
+        # Extract parameters from the request
+        params = dict(request.query_params)
         if params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
             return True
-        sip = session.get("client_ip", "")
-        shost = session.get("host", "")
-        sreferer = session.get("refer")
-        print(sip, shost, sreferer)
+
+        # Gather request-specific information
+        sip = request.client.host
+        shost = request.headers.get("Host", "")
+        sreferer = request.headers.get("Referer", "")
+        suseragent = request.headers.get("User-Agent", "")
+
+        print(sip, shost, sreferer, suseragent)
+
+        # Decode the JWT token
         jwt_data = self.decode_jwt(token)
-        tip = jwt_data.get("ip", "")
-        thost = jwt_data.get("host", "")
-        treferer = jwt_data.get("referer", "")
-        print(sip, tip, shost, thost, sreferer, treferer)
-        if not tip or not thost or not treferer:
-            raise Exception("invalid token")
-        if sip == tip and shost == thost and sreferer == treferer:
+        jwt_auth = jwt_data.get("auth", "")
+
+        if not jwt_auth:
+            raise Exception("Missing auth field in token")
+
+        # Create the MD5 hash of ip + host + referer + user-agent
+        auth_string = f"{sip}{shost}{sreferer}{suseragent}"
+        calculated_md5 = hashlib.md5(auth_string.encode("utf-8")).hexdigest()
+
+        print(f"Calculated MD5: {calculated_md5}, JWT Auth: {jwt_auth}")
+
+        # Compare the calculated hash with the `auth` claim from the JWT
+        if calculated_md5 == jwt_auth:
             return True
-        raise Exception("wrong token")
+
+        raise Exception("Invalid authentication")
 
 class InferenceManager:
-    def __init__(self, config_path="config.json"):
+    def __init__(self, config_path="config.json", ext_model_pathes={}):
         cfg = {}
         with open(config_path, "r", encoding="utf-8") as f:
             cfg = json.load(f)
         self.cfg = cfg
+        self.ext_model_pathes = ext_model_pathes
+
         lora_options_path = cfg.get("loras", "")
         self.model_version = cfg["model_version"]
         self.lora_load_options = self.load_json(lora_options_path)  # Load LoRA load options
         self.lora_models = self.load_index_file("index.json")  # Load index.json
         self.preloaded_loras = []  # Array to store preloaded LoRAs with name and weights
+        self.ip_adapter_faceid_pipeline = None
         self.base_model_pipeline = self.load_base_model()  # Load the base model
+
         self.preload_loras()  # Preload LoRAs based on options
 
     def load_json(self, filepath):
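The rewritten check_auth no longer compares individual ip/host/referer claims; it expects the token to carry a single auth claim equal to MD5(ip + host + referer + user-agent) of the incoming request. For context, a minimal sketch of the issuing side, assuming PyJWT and a shared signing key (the helper name, key, algorithm, and claim layout beyond "auth" are illustrative, not taken from this repo):

    import hashlib
    import jwt  # PyJWT

    def mint_token(ip, host, referer, user_agent, signing_key):
        # The auth claim must be the MD5 hex digest of the concatenated
        # request fingerprint, in the same order check_auth rebuilds it.
        fingerprint = f"{ip}{host}{referer}{user_agent}"
        payload = {"auth": hashlib.md5(fingerprint.encode("utf-8")).hexdigest()}
        return jwt.encode(payload, signing_key, algorithm="HS256")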
@@ -165,6 +188,7 @@ class InferenceManager:
             #unet=unet,
             torch_dtype=torch.bfloat16,
             use_safetensors=True,
+            sampler=cfg.get("sampler"),
             #variant="fp16",
             custom_pipeline="lpw_stable_diffusion_xl",
         )
@@ -175,8 +199,19 @@ class InferenceManager:
 
         load_time = round(time.time() - start, 2)
         print(f"Base model loaded in {load_time}s")
+
+        if cfg.get("load_ip_adapter_faceid", False):
+            if self.model_version in ("pony", "xl"):
+                ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
+                if ip_ckpt:
+                    print("loading ip-adapter-faceid-sdxl model")
+                    self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, "cuda")
+                else:
+                    print("ip-adapter-faceid-sdxl not found, skip")
+
         return pipe
 
+
     def preload_loras(self):
         """Preload all LoRAs marked as 'preload=True' and store for later use."""
         for lora_name, lora_info in self.lora_load_options.items():
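The constructor and load_base_model above now read several new keys from each per-model config. A hypothetical config.json for reference, sketched as the Python dict that json.load would return (key names come from the code above; every value is illustrative):

    example_cfg = {
        "model_version": "xl",           # "pony" and "xl" enable the FaceID branch
        "loras": "loras.json",           # path to the LoRA load-options file
        "sampler": "DPM++ 2M",           # forwarded to the pipeline constructor
        "load_ip_adapter_faceid": True,  # opt in to building ip_adapter_faceid_pipeline
        "inference_steps": 30,
        "guidance_scale": 7,
        "negative_prompt": "",
        "prompt": "{prompt}",
    }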
@@ -279,9 +314,36 @@ class ModelManager:
 
         :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
         """
+        print("downloading models...")
+        self.ext_model_pathes = {
+            "ip-adapter-faceid-sdxl": hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sdxl.bin", repo_type="model")
+        }
+
         self.models = {}
+        self.ext_models = {}
         self.model_directory = model_directory
         self.load_models()
+
+    # not enabled at the moment
+    def load_instant_x(self):
+        # load all InstantID models
+        hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
+        hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
+        hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
+        os.makedirs("./models", exist_ok=True)
+        download_from_hf("models/antelopev2/1k3d68.onnx", local_dir="./models")
+        download_from_hf("models/antelopev2/2d106det.onnx", local_dir="./models")
+        download_from_hf("models/antelopev2/genderage.onnx", local_dir="./models")
+        download_from_hf("models/antelopev2/glintr100.onnx", local_dir="./models")
+        download_from_hf("models/antelopev2/scrfd_10g_bnkps.onnx", local_dir="./models")
+
+        # prepare 'antelopev2' under ./models
+        app = FaceAnalysis(name="antelopev2", root="./", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+        app.prepare(ctx_id=0, det_size=(640, 640))
+
+        # prepare models under ./checkpoints
+        face_adapter = "./checkpoints/ip-adapter.bin"
+        controlnet_path = "./checkpoints/ControlNetModel"
 
     def load_models(self):
         """
@@ -299,7 +361,7 @@ class ModelManager:
             print(f"Initializing model: {model_name} from {file_path}")
             try:
                 # Initialize InferenceManager for each model
-                self.models[model_name] = InferenceManager(config_path=file_path)
+                self.models[model_name] = InferenceManager(config_path=file_path, ext_model_pathes=self.ext_model_pathes)
             except Exception as e:
                 print(traceback.format_exc())
                 print(f"Failed to initialize model {model_name} from {file_path}: {e}")
@@ -352,9 +414,141 @@ class ModelManager:
                 model.release(model.base_model_pipeline)
             except Exception as e:
                 print(f"Failed to release model {model_id}: {e}")
+
+    @spaces.GPU(duration=40)
+    def generate_with_faceid(self, model_id, request, inference_params, progress=gr.Progress(track_tqdm=True)):
+        token = inference_params.get("token", "")  # token is assumed to arrive alongside the other params
+        auth_helper.check_auth(request, token)
+        model = self.models.get(model_id)
+        if not model:
+            raise Exception(f"invalid model_id {model_id}")
+        if not model.ip_adapter_faceid_pipeline:
+            raise Exception("model does not support ip adapter")
+        pipe = model.ip_adapter_faceid_pipeline
+        cfg = model.cfg
+        p = inference_params.get("prompt")
+        negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
+        steps = inference_params.get("steps", cfg.get("inference_steps", 30))
+        guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
+        width = inference_params.get("width", cfg.get("width", 512))
+        height = inference_params.get("height", cfg.get("height", 512))
+        images = inference_params.get("images", [])
+        likeness_strength = inference_params.get("likeness_strength", 0.4)
+        face_strength = inference_params.get("face_strength", 0.1)
+        sampler = inference_params.get("sampler", cfg.get("sampler", ""))
+        lora_list = inference_params.get("loras", [])
+        seed = inference_params.get("seed", 0)  # 0 means: randomize below
+
+        if not images:
+            raise Exception("face images not provided")
+        start = time.time()
+        pipe.to("cuda")
+        print("loading face analysis...")
+        app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+        app.prepare(ctx_id=0, det_size=(512, 512))
+
+        # Embed every reference face, then average the embeddings
+        faceid_all_embeds = []
+        for image in images:
+            face = cv2.imread(image)
+            faces = app.get(face)
+            faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
+            faceid_all_embeds.append(faceid_embed)
+
+        average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
+
+        print("start inference...")
+        style_selection = ""
+        use_negative_prompt = True
+        randomize_seed = True
+        seed = seed or int(randomize_seed_fn(seed, randomize_seed))
+        p = remove_child_related_content(p)
+        prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
+        generator = torch.Generator(pipe.device).manual_seed(seed)
+        print(f"generate: p={p}, np={negative_prompt}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
+        images = pipe(
+            prompt=prompt_str,
+            negative_prompt=negative_prompt,
+            faceid_embeds=average_embedding,
+            scale=likeness_strength,
+            width=width,
+            height=height,
+            guidance_scale=face_strength,
+            num_inference_steps=steps,
+            generator=generator,
+            num_images_per_prompt=1,
+            output_type="pil",
+            #callback_on_step_end=callback_dynamic_cfg,
+            #callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
+        ).images
+        cost = round(time.time() - start, 2)
+        print(f"inference done in {cost}s")
+        images = [save_image(img) for img in images]
+        image_paths = [i[1] for i in images]
+        print(prompt_str, image_paths)
+        return [i[0] for i in images]
+
+    @spaces.GPU(duration=40)
+    def generate(self, model_id, request, inference_params, progress=gr.Progress(track_tqdm=True)):
+        def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
+            # Disable classifier-free guidance after a configurable fraction of the steps
+            cfg_disabling_at = cfg.get("cfg_disabling_rate", 0.75)
+            if step_index == int(pipe.num_timesteps * cfg_disabling_at):
+                callback_kwargs["prompt_embeds"] = callback_kwargs["prompt_embeds"].chunk(2)[-1]
+                callback_kwargs["add_text_embeds"] = callback_kwargs["add_text_embeds"].chunk(2)[-1]
+                callback_kwargs["add_time_ids"] = callback_kwargs["add_time_ids"].chunk(2)[-1]
+                pipe._guidance_scale = 0.0
+            return callback_kwargs
+
+        token = inference_params.get("token", "")  # token is assumed to arrive alongside the other params
+        auth_helper.check_auth(request, token)
+        model = self.models.get(model_id)
+        if not model:
+            raise Exception(f"invalid model_id {model_id}")
+
+        cfg = model.cfg
+        p = inference_params.get("prompt")
+        negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
+        inference_steps = inference_params.get("steps", cfg.get("inference_steps", 30))
+        guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
+        width = inference_params.get("width", cfg.get("width", 512))
+        height = inference_params.get("height", cfg.get("height", 512))
+        sampler = inference_params.get("sampler", cfg.get("sampler", ""))
+        lora_list = inference_params.get("loras", [])
+        seed = inference_params.get("seed", 0)  # 0 means: randomize below
+
+        pipe = model.build_pipeline_with_lora(lora_list, sampler, lora_list)
+
+        start = time.time()
+        pipe.to("cuda")
+        print("start inference...")
+        style_selection = ""
+        use_negative_prompt = True
+        randomize_seed = True
+        seed = seed or int(randomize_seed_fn(seed, randomize_seed))
+        guidance_scale = guidance_scale or cfg.get("guidance_scale", 7.5)
+        p = remove_child_related_content(p)
+        prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
+        generator = torch.Generator(pipe.device).manual_seed(seed)
+        print(f"generate: p={p}, np={negative_prompt}, steps={inference_steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
+        images = pipe(
+            prompt=prompt_str,
+            negative_prompt=negative_prompt,
+            width=width,
+            height=height,
+            guidance_scale=guidance_scale,
+            num_inference_steps=inference_steps,
+            generator=generator,
+            num_images_per_prompt=1,
+            output_type="pil",
+            callback_on_step_end=callback_dynamic_cfg,
+            callback_on_step_end_tensor_inputs=["prompt_embeds", "add_text_embeds", "add_time_ids"],
+        ).images
+        cost = round(time.time() - start, 2)
+        print(f"inference done in {cost}s")
+        images = [save_image(img) for img in images]
+        image_paths = [i[1] for i in images]
+        print(prompt_str, image_paths)
+        return [i[0] for i in images]
 
 # Hugging Face file download function - returns only file path
-def download_from_hf(filename, local_dir=None):
+def download_from_hf(filename, local_dir=None, repo_id=DATASET_ID, repo_type="dataset"):
     try:
         file_path = hf_hub_download(
             filename=filename,
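A hypothetical call into the new FaceID entry point, assuming model_manager is a ModelManager instance, request is the framework request object, and the token travels inside inference_params as the code above expects (model key and all values are illustrative):

    images = model_manager.generate_with_faceid(
        model_id="example-xl",  # hypothetical key from a ./models/*.json config
        request=request,
        inference_params={
            "prompt": "portrait photo, studio lighting",
            "images": ["face_1.png", "face_2.png"],  # local paths to reference face photos
            "steps": 30,
            "width": 832,
            "height": 1216,
            "likeness_strength": 0.4,
            "face_strength": 0.1,
            "token": token,  # JWT whose auth claim matches the request fingerprint
        },
    )

With the new defaults, download_from_hf(filename) now pulls from the dataset repo named by the module-level DATASET_ID (not shown in this diff), so a plain download_from_hf("models/antelopev2/1k3d68.onnx", local_dir="./models") resolves against that dataset.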
 