nsfwalex commited on
Commit
4d12863
·
verified ·
1 Parent(s): bd2dc5f

Update inference_manager.py

Browse files
Files changed (1) hide show
  1. inference_manager.py +23 -6
inference_manager.py CHANGED
@@ -27,12 +27,29 @@ import re
27
  import gradio as gr
28
  import uuid
29
  from PIL import Image
30
- MAX_SEED = 12211231#np.iinfo(np.int32).max
31
  #from onediffx import compile_pipe, save_pipe, load_pipe
32
 
33
  HF_TOKEN = os.getenv('HF_TOKEN')
34
  VAR_PUBLIC_KEY = os.getenv('PUBLIC_KEY')
35
  DATASET_ID = 'nsfwalex/checkpoint_n_lora'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  class AuthHelper:
38
  def load_public_key_from_file(self):
@@ -179,7 +196,7 @@ class InferenceManager:
179
  pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
180
  else:
181
  use_vae = cfg.get("vae", "")
182
- if not use_vae:
183
  vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
184
  elif use_vae == "tae":
185
  vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
@@ -193,7 +210,7 @@ class InferenceManager:
193
  torch_dtype=torch.bfloat16,
194
  use_safetensors=True,
195
  #variant="fp16",
196
- #custom_pipeline = "lpw_stable_diffusion_xl",
197
  )
198
  #pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
199
  clip_skip = cfg.get("clip_skip", 1)
@@ -208,7 +225,7 @@ class InferenceManager:
208
  ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
209
  if ip_ckpt:
210
  print(f"loading ip adapter model...")
211
- self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, 'cuda')
212
  else:
213
  print("ip-adapter-faceid-sdxl not found, skip")
214
 
@@ -583,8 +600,8 @@ class ModelManager:
583
  generator=generator,
584
  num_images_per_prompt=1,
585
  output_type="pil",
586
- callback_on_step_end=callback_dynamic_cfg,
587
- callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
588
  ).images
589
  cost = round(time.time() - start, 2)
590
  print(f"inference done in {cost}s")
 
27
  import gradio as gr
28
  import uuid
29
  from PIL import Image
30
+ MAX_SEED = np.iinfo(np.int32).max
31
  #from onediffx import compile_pipe, save_pipe, load_pipe
32
 
33
  HF_TOKEN = os.getenv('HF_TOKEN')
34
  VAR_PUBLIC_KEY = os.getenv('PUBLIC_KEY')
35
  DATASET_ID = 'nsfwalex/checkpoint_n_lora'
36
+ scheduler_config = {
37
+ "num_train_timesteps": 1000,
38
+ "beta_start": 0.00085,
39
+ "beta_end": 0.012,
40
+ "beta_schedule": "scaled_linear",
41
+ "set_alpha_to_one": False,
42
+ "steps_offset": 1,
43
+ "prediction_type": "epsilon",
44
+ }
45
+ samplers = {
46
+ "Euler a": EulerAncestralDiscreteScheduler.from_config(scheduler_config),
47
+ "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True),
48
+ "DPM2 a": DPMSolverMultistepScheduler.from_config(scheduler_config),
49
+ "DPM++ SDE": DPMSolverSDEScheduler.from_config(scheduler_config),
50
+ "DPM++ 2M SDE": DPMSolverSDEScheduler.from_config(scheduler_config, use_2m=True),
51
+ "DPM++ 2S a": DPMSolverMultistepScheduler.from_config(scheduler_config, use_2s=True)
52
+ }
53
 
54
  class AuthHelper:
55
  def load_public_key_from_file(self):
 
196
  pipe = StableDiffusionPipeline.from_pretrained(ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True)
197
  else:
198
  use_vae = cfg.get("vae", "")
199
+ if not use_vae or True:#!TEST! default vae for test
200
  vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
201
  elif use_vae == "tae":
202
  vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
 
210
  torch_dtype=torch.bfloat16,
211
  use_safetensors=True,
212
  #variant="fp16",
213
+ custom_pipeline = "lpw_stable_diffusion_xl",
214
  )
215
  #pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
216
  clip_skip = cfg.get("clip_skip", 1)
 
225
  ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
226
  if ip_ckpt:
227
  print(f"loading ip adapter model...")
228
+ self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(pipe, ip_ckpt, 'cuda', torch_dtype=torch.bfloat16)
229
  else:
230
  print("ip-adapter-faceid-sdxl not found, skip")
231
 
 
600
  generator=generator,
601
  num_images_per_prompt=1,
602
  output_type="pil",
603
+ #callback_on_step_end=callback_dynamic_cfg,
604
+ #callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
605
  ).images
606
  cost = round(time.time() - start, 2)
607
  print(f"inference done in {cost}s")