nsfwalex commited on
Commit
a0b5cec
·
verified ·
1 Parent(s): 101f7e2

Update inference_manager.py

Browse files
Files changed (1) hide show
  1. inference_manager.py +4 -5
inference_manager.py CHANGED
@@ -313,7 +313,9 @@ class ModelManager:
313
  :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
314
  """
315
  print("downloading models")
316
- print("downloading antelopev2...")
 
 
317
  #download_from_hf()
318
  self.ext_model_pathes = {
319
  "ip-adapter-faceid-sdxl": hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sdxl.bin", repo_type="model")
@@ -440,14 +442,11 @@ class ModelManager:
440
  raise Exception(f"face images not provided")
441
  start = time.time()
442
  model.base_model_pipeline.to("cuda")
443
- print("loading face analysis...")
444
- app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
445
- app.prepare(ctx_id=0, det_size=(512, 512))
446
 
447
  faceid_all_embeds = []
448
  for image in images:
449
  face = cv2.imread(image)
450
- faces = app.get(face)
451
  faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
452
  faceid_all_embeds.append(faceid_embed)
453
 
 
313
  :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
314
  """
315
  print("downloading models")
316
+ print("loading face analysis...")
317
+ self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
318
+ self.app.prepare(ctx_id=0, det_size=(512, 512))
319
  #download_from_hf()
320
  self.ext_model_pathes = {
321
  "ip-adapter-faceid-sdxl": hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sdxl.bin", repo_type="model")
 
442
  raise Exception(f"face images not provided")
443
  start = time.time()
444
  model.base_model_pipeline.to("cuda")
 
 
 
445
 
446
  faceid_all_embeds = []
447
  for image in images:
448
  face = cv2.imread(image)
449
+ faces = self.app.get(face)
450
  faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
451
  faceid_all_embeds.append(faceid_embed)
452