Upload folder using huggingface_hub
- inference.py +8 -8
- internals/pipelines/realtime_draw.py +10 -2
inference.py
CHANGED

@@ -727,9 +727,6 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
     inpainter.init(text2img_pipe)
     controlnet.init(text2img_pipe)

-    safety_checker.apply(text2img_pipe)
-    safety_checker.apply(img2img_pipe)
-
     if task_type == TaskType.INPAINT:
         inpainter.load()
         safety_checker.apply(inpainter)

@@ -756,7 +753,11 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
     elif task_type == TaskType.POSE:
         controlnet.load_model("pose")

-
+
+def apply_safety_checkers():
+    safety_checker.apply(text2img_pipe)
+    safety_checker.apply(img2img_pipe)
+    safety_checker.apply(controlnet)


 def model_fn(model_dir):

@@ -798,6 +799,9 @@ def predict_fn(data, pipe):
         task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
     )

+    # Apply safety checkers
+    apply_safety_checkers()
+
     # Realtime generation apis
     if task_type == TaskType.RT_DRAW_SEG:
         return rt_draw_seg(task)

@@ -814,10 +818,6 @@ def predict_fn(data, pipe):
         avatar.fetch_from_network(task.get_model_id())

     if task_type == TaskType.TEXT_TO_IMAGE:
-        # character sheet
-        # if "character sheet" in task.get_prompt().lower():
-        #     return pose(task, s3_outkey="", poses=pickPoses())
-        # else:
         return text2img(task)
     elif task_type == TaskType.IMAGE_TO_IMAGE:
         return img2img(task)
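In summary, this commit removes the blanket safety_checker.apply(...) calls from load_model_by_task and introduces a module-level apply_safety_checkers() helper (which now also covers controlnet) that predict_fn calls on every request, presumably so the checker stays attached even after pipelines are reloaded or swapped. Below is a minimal, self-contained sketch of that pattern; StubPipeline and StubSafetyChecker are hypothetical stand-ins, not the repository's actual classes.

# Minimal sketch of the pattern introduced here, using hypothetical stand-ins
# (StubPipeline, StubSafetyChecker) rather than the repository's real objects:
# the checker is wired onto every pipeline in one helper that runs per request
# instead of once inside the model-loading routine.

from typing import Callable, List


class StubPipeline:
    """Stand-in for a diffusers-style pipeline object."""

    def __init__(self) -> None:
        self.hooks: List[Callable] = []


class StubSafetyChecker:
    """Stand-in for the repo's safety_checker wrapper."""

    def apply(self, pipe: StubPipeline) -> None:
        # Keep the call idempotent so running it on every request
        # does not stack duplicate hooks on the same pipeline.
        if self._flag_nsfw not in pipe.hooks:
            pipe.hooks.append(self._flag_nsfw)

    @staticmethod
    def _flag_nsfw(image) -> bool:
        return False  # placeholder decision


safety_checker = StubSafetyChecker()
text2img_pipe = StubPipeline()
img2img_pipe = StubPipeline()
controlnet = StubPipeline()


def apply_safety_checkers() -> None:
    # One place that attaches the checker to every image-producing pipeline,
    # mirroring the helper added in this commit.
    safety_checker.apply(text2img_pipe)
    safety_checker.apply(img2img_pipe)
    safety_checker.apply(controlnet)


def predict_fn(data, pipe=None):
    apply_safety_checkers()  # runs per request, after models are loaded
    # ... dispatch to text2img / img2img / etc. as in inference.py ...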
internals/pipelines/realtime_draw.py
CHANGED

@@ -72,10 +72,16 @@ class RealtimeDraw(AbstractPipeline):
         torch.manual_seed(seed)

         if not image:
-            image = Image.new("RGB", (512, 512), color=0)
+            size = (512, 512)
+            if image2:
+                size = image2.size
+            image = Image.new("RGB", size, color=0)

         if not image2:
-            image2 = Image.new("RGB", (512, 512), color=0)
+            size = (512, 512)
+            if image:
+                size = image.size
+            image2 = Image.new("RGB", size, color=0)

         image = ImageUtil.resize_image(image, 512)

@@ -91,6 +97,8 @@ class RealtimeDraw(AbstractPipeline):
             negative_prompt=negative_prompt,
             guidance_scale=10,
             strength=0.9,
+            width=image.size[0],
+            height=image.size[1],
             controlnet_conditioning_scale=[1.0, 0.8],
         ).images[0]
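The net effect of the realtime_draw.py change: when only one of the two input images is provided, the missing one now defaults to a black canvas matching the other image's size rather than a hard-coded 512x512, and the input dimensions are forwarded to the pipeline call as explicit width/height, presumably so non-square inputs generate at matching dimensions. A small standalone sketch of that defaulting logic follows, assuming Pillow images; make_blank_like is a hypothetical helper name, not part of the repository.

# Sketch of the size-defaulting logic as a standalone helper, assuming Pillow
# images. `make_blank_like` is a hypothetical name used for illustration only.

from typing import Optional, Tuple

from PIL import Image


def make_blank_like(other: Optional[Image.Image],
                    fallback: Tuple[int, int] = (512, 512)) -> Image.Image:
    """Return a black RGB canvas matching `other`'s size, or `fallback`
    when no reference image is available."""
    size = other.size if other is not None else fallback
    return Image.new("RGB", size, color=0)


# Usage mirroring the diff: whichever input is missing is replaced by a blank
# canvas sized like its counterpart, so both inputs agree before resizing.
image: Optional[Image.Image] = None
image2: Optional[Image.Image] = Image.new("RGB", (640, 384), color=(255, 0, 0))

if image is None:
    image = make_blank_like(image2)
if image2 is None:
    image2 = make_blank_like(image)

assert image.size == image2.size == (640, 384)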