jayparmr commited on
Commit
3e9c18d
1 Parent(s): 99a0484

Upload folder using huggingface_hub

Browse files
inference.py CHANGED
@@ -727,9 +727,6 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
727
  inpainter.init(text2img_pipe)
728
  controlnet.init(text2img_pipe)
729
 
730
- safety_checker.apply(text2img_pipe)
731
- safety_checker.apply(img2img_pipe)
732
-
733
  if task_type == TaskType.INPAINT:
734
  inpainter.load()
735
  safety_checker.apply(inpainter)
@@ -756,7 +753,11 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
756
  elif task_type == TaskType.POSE:
757
  controlnet.load_model("pose")
758
 
759
- safety_checker.apply(controlnet)
 
 
 
 
760
 
761
 
762
  def model_fn(model_dir):
@@ -798,6 +799,9 @@ def predict_fn(data, pipe):
798
  task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
799
  )
800
 
 
 
 
801
  # Realtime generation apis
802
  if task_type == TaskType.RT_DRAW_SEG:
803
  return rt_draw_seg(task)
@@ -814,10 +818,6 @@ def predict_fn(data, pipe):
814
  avatar.fetch_from_network(task.get_model_id())
815
 
816
  if task_type == TaskType.TEXT_TO_IMAGE:
817
- # character sheet
818
- # if "character sheet" in task.get_prompt().lower():
819
- # return pose(task, s3_outkey="", poses=pickPoses())
820
- # else:
821
  return text2img(task)
822
  elif task_type == TaskType.IMAGE_TO_IMAGE:
823
  return img2img(task)
 
727
  inpainter.init(text2img_pipe)
728
  controlnet.init(text2img_pipe)
729
 
 
 
 
730
  if task_type == TaskType.INPAINT:
731
  inpainter.load()
732
  safety_checker.apply(inpainter)
 
753
  elif task_type == TaskType.POSE:
754
  controlnet.load_model("pose")
755
 
756
+
757
+ def apply_safety_checkers():
758
+ safety_checker.apply(text2img_pipe)
759
+ safety_checker.apply(img2img_pipe)
760
+ safety_checker.apply(controlnet)
761
 
762
 
763
  def model_fn(model_dir):
 
799
  task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
800
  )
801
 
802
+ # Apply safety checkers
803
+ apply_safety_checkers()
804
+
805
  # Realtime generation apis
806
  if task_type == TaskType.RT_DRAW_SEG:
807
  return rt_draw_seg(task)
 
818
  avatar.fetch_from_network(task.get_model_id())
819
 
820
  if task_type == TaskType.TEXT_TO_IMAGE:
 
 
 
 
821
  return text2img(task)
822
  elif task_type == TaskType.IMAGE_TO_IMAGE:
823
  return img2img(task)
internals/pipelines/realtime_draw.py CHANGED
@@ -72,10 +72,16 @@ class RealtimeDraw(AbstractPipeline):
72
  torch.manual_seed(seed)
73
 
74
  if not image:
75
- image = Image.new("RGB", (512, 512), color=0)
 
 
 
76
 
77
  if not image2:
78
- image2 = Image.new("RGB", image.size, color=0)
 
 
 
79
 
80
  image = ImageUtil.resize_image(image, 512)
81
 
@@ -91,6 +97,8 @@ class RealtimeDraw(AbstractPipeline):
91
  negative_prompt=negative_prompt,
92
  guidance_scale=10,
93
  strength=0.9,
 
 
94
  controlnet_conditioning_scale=[1.0, 0.8],
95
  ).images[0]
96
 
 
72
  torch.manual_seed(seed)
73
 
74
  if not image:
75
+ size = (512, 512)
76
+ if image2:
77
+ size = image2.size
78
+ image = Image.new("RGB", size, color=0)
79
 
80
  if not image2:
81
+ size = (512, 512)
82
+ if image:
83
+ size = image.size
84
+ image2 = Image.new("RGB", size, color=0)
85
 
86
  image = ImageUtil.resize_image(image, 512)
87
 
 
97
  negative_prompt=negative_prompt,
98
  guidance_scale=10,
99
  strength=0.9,
100
+ width=image.size[0],
101
+ height=image.size[1],
102
  controlnet_conditioning_scale=[1.0, 0.8],
103
  ).images[0]
104