Upload folder using huggingface_hub
Browse files- inference.py +37 -35
- inference2.py +15 -9
- internals/data/dataAccessor.py +27 -10
- internals/pipelines/commons.py +60 -16
- internals/pipelines/controlnets.py +132 -130
- internals/pipelines/high_res.py +1 -1
- internals/pipelines/inpainter.py +48 -12
- internals/pipelines/remove_background.py +54 -9
- internals/pipelines/replace_background.py +17 -7
- internals/pipelines/twoStepPipeline.py +1 -1
- internals/util/cache.py +13 -3
- internals/util/commons.py +2 -2
- internals/util/config.py +5 -0
- internals/util/lora_style.py +5 -0
- internals/util/model_loader.py +3 -0
- pyproject.toml +1 -1
- requirements.txt +4 -0
inference.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
import os
|
2 |
from typing import List, Optional
|
3 |
|
|
|
4 |
import torch
|
5 |
|
6 |
import internals.util.prompt as prompt_util
|
7 |
-
from internals.data.dataAccessor import update_db
|
8 |
from internals.data.task import Task, TaskType
|
9 |
from internals.pipelines.commons import Img2Img, Text2Img
|
10 |
from internals.pipelines.controlnets import ControlNet
|
@@ -18,11 +19,15 @@ from internals.pipelines.replace_background import ReplaceBackground
|
|
18 |
from internals.pipelines.safety_checker import SafetyChecker
|
19 |
from internals.util.args import apply_style_args
|
20 |
from internals.util.avatar import Avatar
|
21 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
22 |
from internals.util.commons import download_image, upload_image, upload_images
|
23 |
-
from internals.util.config import (
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
from internals.util.failure_hander import FailureHandler
|
27 |
from internals.util.lora_style import LoraStyle
|
28 |
from internals.util.model_loader import load_model_from_config
|
@@ -80,7 +85,7 @@ def canny(task: Task):
|
|
80 |
|
81 |
width, height = get_intermediate_dimension(task)
|
82 |
|
83 |
-
controlnet.
|
84 |
|
85 |
# pipe2 is used for canny and pose
|
86 |
lora_patcher = lora_style.get_patcher(
|
@@ -88,7 +93,7 @@ def canny(task: Task):
|
|
88 |
)
|
89 |
lora_patcher.patch()
|
90 |
|
91 |
-
images, has_nsfw = controlnet.
|
92 |
prompt=prompt,
|
93 |
imageUrl=task.get_imageUrl(),
|
94 |
seed=task.get_seed(),
|
@@ -132,12 +137,12 @@ def tile_upscale(task: Task):
|
|
132 |
|
133 |
prompt = get_patched_prompt_tile_upscale(task)
|
134 |
|
135 |
-
controlnet.
|
136 |
|
137 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
138 |
lora_patcher.patch()
|
139 |
|
140 |
-
images, has_nsfw = controlnet.
|
141 |
imageUrl=task.get_imageUrl(),
|
142 |
seed=task.get_seed(),
|
143 |
steps=task.get_steps(),
|
@@ -169,14 +174,14 @@ def scribble(task: Task):
|
|
169 |
|
170 |
width, height = get_intermediate_dimension(task)
|
171 |
|
172 |
-
controlnet.
|
173 |
|
174 |
lora_patcher = lora_style.get_patcher(
|
175 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
176 |
)
|
177 |
lora_patcher.patch()
|
178 |
|
179 |
-
images, has_nsfw = controlnet.
|
180 |
imageUrl=task.get_imageUrl(),
|
181 |
seed=task.get_seed(),
|
182 |
steps=task.get_steps(),
|
@@ -215,14 +220,14 @@ def linearart(task: Task):
|
|
215 |
|
216 |
width, height = get_intermediate_dimension(task)
|
217 |
|
218 |
-
controlnet.
|
219 |
|
220 |
lora_patcher = lora_style.get_patcher(
|
221 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
222 |
)
|
223 |
lora_patcher.patch()
|
224 |
|
225 |
-
images, has_nsfw = controlnet.
|
226 |
imageUrl=task.get_imageUrl(),
|
227 |
seed=task.get_seed(),
|
228 |
steps=task.get_steps(),
|
@@ -261,7 +266,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
261 |
|
262 |
width, height = get_intermediate_dimension(task)
|
263 |
|
264 |
-
controlnet.
|
265 |
|
266 |
# pipe2 is used for canny and pose
|
267 |
lora_patcher = lora_style.get_patcher(
|
@@ -291,7 +296,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
291 |
)
|
292 |
condition_image = ControlNet.linearart_condition_image(src_image)
|
293 |
|
294 |
-
images, has_nsfw = controlnet.
|
295 |
prompt=prompt,
|
296 |
image=poses,
|
297 |
condition_image=[condition_image] * num_return_sequences,
|
@@ -440,7 +445,7 @@ def inpaint(task: Task):
|
|
440 |
|
441 |
generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
|
442 |
|
443 |
-
|
444 |
|
445 |
return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
446 |
|
@@ -469,12 +474,13 @@ def replace_bg(task: Task):
|
|
469 |
product_scale_width=task.get_image_scale(),
|
470 |
apply_high_res=task.get_high_res_fix(),
|
471 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
|
|
472 |
)
|
473 |
|
474 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
475 |
|
476 |
lora_patcher.cleanup()
|
477 |
-
|
478 |
|
479 |
return {
|
480 |
"modified_prompts": prompt,
|
@@ -484,38 +490,33 @@ def replace_bg(task: Task):
|
|
484 |
|
485 |
|
486 |
def load_model_by_task(task: Task):
|
487 |
-
|
488 |
-
|
489 |
-
if (
|
490 |
-
task.get_type()
|
491 |
-
in [
|
492 |
-
TaskType.TEXT_TO_IMAGE,
|
493 |
-
TaskType.IMAGE_TO_IMAGE,
|
494 |
-
TaskType.INPAINT,
|
495 |
-
]
|
496 |
-
and not text2img_pipe.is_loaded()
|
497 |
-
):
|
498 |
text2img_pipe.load(get_model_dir())
|
499 |
img2img_pipe.create(text2img_pipe)
|
500 |
-
inpainter.load()
|
501 |
high_res.load(img2img_pipe)
|
502 |
|
|
|
|
|
|
|
503 |
safety_checker.apply(text2img_pipe)
|
504 |
safety_checker.apply(img2img_pipe)
|
|
|
|
|
|
|
505 |
safety_checker.apply(inpainter)
|
506 |
elif task.get_type() == TaskType.REPLACE_BG:
|
507 |
replace_background.load(inpainter=inpainter, high_res=high_res)
|
508 |
else:
|
509 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
510 |
-
controlnet.
|
511 |
elif task.get_type() == TaskType.CANNY:
|
512 |
-
controlnet.
|
513 |
elif task.get_type() == TaskType.SCRIBBLE:
|
514 |
-
controlnet.
|
515 |
elif task.get_type() == TaskType.LINEARART:
|
516 |
-
controlnet.
|
517 |
elif task.get_type() == TaskType.POSE:
|
518 |
-
controlnet.
|
519 |
|
520 |
safety_checker.apply(controlnet)
|
521 |
|
@@ -589,7 +590,8 @@ def predict_fn(data, pipe):
|
|
589 |
else:
|
590 |
raise Exception("Invalid task type")
|
591 |
except Exception as e:
|
592 |
-
print(f"Error: {e}")
|
593 |
slack.error_alert(task, e)
|
594 |
controlnet.cleanup()
|
|
|
|
|
595 |
return None
|
|
|
1 |
import os
|
2 |
from typing import List, Optional
|
3 |
|
4 |
+
import traceback
|
5 |
import torch
|
6 |
|
7 |
import internals.util.prompt as prompt_util
|
8 |
+
from internals.data.dataAccessor import update_db, update_db_source_failed
|
9 |
from internals.data.task import Task, TaskType
|
10 |
from internals.pipelines.commons import Img2Img, Text2Img
|
11 |
from internals.pipelines.controlnets import ControlNet
|
|
|
19 |
from internals.pipelines.safety_checker import SafetyChecker
|
20 |
from internals.util.args import apply_style_args
|
21 |
from internals.util.avatar import Avatar
|
22 |
+
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
|
23 |
from internals.util.commons import download_image, upload_image, upload_images
|
24 |
+
from internals.util.config import (
|
25 |
+
get_model_dir,
|
26 |
+
num_return_sequences,
|
27 |
+
set_configs_from_task,
|
28 |
+
set_model_config,
|
29 |
+
set_root_dir,
|
30 |
+
)
|
31 |
from internals.util.failure_hander import FailureHandler
|
32 |
from internals.util.lora_style import LoraStyle
|
33 |
from internals.util.model_loader import load_model_from_config
|
|
|
85 |
|
86 |
width, height = get_intermediate_dimension(task)
|
87 |
|
88 |
+
controlnet.load_model("canny")
|
89 |
|
90 |
# pipe2 is used for canny and pose
|
91 |
lora_patcher = lora_style.get_patcher(
|
|
|
93 |
)
|
94 |
lora_patcher.patch()
|
95 |
|
96 |
+
images, has_nsfw = controlnet.process(
|
97 |
prompt=prompt,
|
98 |
imageUrl=task.get_imageUrl(),
|
99 |
seed=task.get_seed(),
|
|
|
137 |
|
138 |
prompt = get_patched_prompt_tile_upscale(task)
|
139 |
|
140 |
+
controlnet.load_model("tile_upscaler")
|
141 |
|
142 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
143 |
lora_patcher.patch()
|
144 |
|
145 |
+
images, has_nsfw = controlnet.process(
|
146 |
imageUrl=task.get_imageUrl(),
|
147 |
seed=task.get_seed(),
|
148 |
steps=task.get_steps(),
|
|
|
174 |
|
175 |
width, height = get_intermediate_dimension(task)
|
176 |
|
177 |
+
controlnet.load_model("scribble")
|
178 |
|
179 |
lora_patcher = lora_style.get_patcher(
|
180 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
181 |
)
|
182 |
lora_patcher.patch()
|
183 |
|
184 |
+
images, has_nsfw = controlnet.process(
|
185 |
imageUrl=task.get_imageUrl(),
|
186 |
seed=task.get_seed(),
|
187 |
steps=task.get_steps(),
|
|
|
220 |
|
221 |
width, height = get_intermediate_dimension(task)
|
222 |
|
223 |
+
controlnet.load_model("linearart")
|
224 |
|
225 |
lora_patcher = lora_style.get_patcher(
|
226 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
227 |
)
|
228 |
lora_patcher.patch()
|
229 |
|
230 |
+
images, has_nsfw = controlnet.process(
|
231 |
imageUrl=task.get_imageUrl(),
|
232 |
seed=task.get_seed(),
|
233 |
steps=task.get_steps(),
|
|
|
266 |
|
267 |
width, height = get_intermediate_dimension(task)
|
268 |
|
269 |
+
controlnet.load_model("pose")
|
270 |
|
271 |
# pipe2 is used for canny and pose
|
272 |
lora_patcher = lora_style.get_patcher(
|
|
|
296 |
)
|
297 |
condition_image = ControlNet.linearart_condition_image(src_image)
|
298 |
|
299 |
+
images, has_nsfw = controlnet.process(
|
300 |
prompt=prompt,
|
301 |
image=poses,
|
302 |
condition_image=[condition_image] * num_return_sequences,
|
|
|
445 |
|
446 |
generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
|
447 |
|
448 |
+
clear_cuda_and_gc()
|
449 |
|
450 |
return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
451 |
|
|
|
474 |
product_scale_width=task.get_image_scale(),
|
475 |
apply_high_res=task.get_high_res_fix(),
|
476 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
477 |
+
model_type=task.get_modelType(),
|
478 |
)
|
479 |
|
480 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
481 |
|
482 |
lora_patcher.cleanup()
|
483 |
+
clear_cuda_and_gc()
|
484 |
|
485 |
return {
|
486 |
"modified_prompts": prompt,
|
|
|
490 |
|
491 |
|
492 |
def load_model_by_task(task: Task):
|
493 |
+
if not text2img_pipe.is_loaded():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
494 |
text2img_pipe.load(get_model_dir())
|
495 |
img2img_pipe.create(text2img_pipe)
|
|
|
496 |
high_res.load(img2img_pipe)
|
497 |
|
498 |
+
inpainter.init(text2img_pipe)
|
499 |
+
controlnet.init(text2img_pipe)
|
500 |
+
|
501 |
safety_checker.apply(text2img_pipe)
|
502 |
safety_checker.apply(img2img_pipe)
|
503 |
+
|
504 |
+
if task.get_type() == TaskType.INPAINT:
|
505 |
+
inpainter.load()
|
506 |
safety_checker.apply(inpainter)
|
507 |
elif task.get_type() == TaskType.REPLACE_BG:
|
508 |
replace_background.load(inpainter=inpainter, high_res=high_res)
|
509 |
else:
|
510 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
511 |
+
controlnet.load_model("tile_upscaler")
|
512 |
elif task.get_type() == TaskType.CANNY:
|
513 |
+
controlnet.load_model("canny")
|
514 |
elif task.get_type() == TaskType.SCRIBBLE:
|
515 |
+
controlnet.load_model("scribble")
|
516 |
elif task.get_type() == TaskType.LINEARART:
|
517 |
+
controlnet.load_model("linearart")
|
518 |
elif task.get_type() == TaskType.POSE:
|
519 |
+
controlnet.load_model("pose")
|
520 |
|
521 |
safety_checker.apply(controlnet)
|
522 |
|
|
|
590 |
else:
|
591 |
raise Exception("Invalid task type")
|
592 |
except Exception as e:
|
|
|
593 |
slack.error_alert(task, e)
|
594 |
controlnet.cleanup()
|
595 |
+
traceback.print_exc()
|
596 |
+
update_db_source_failed(task.get_sourceId(), task.get_userId())
|
597 |
return None
|
inference2.py
CHANGED
@@ -13,17 +13,19 @@ from internals.pipelines.img_to_text import Image2Text
|
|
13 |
from internals.pipelines.inpainter import InPainter
|
14 |
from internals.pipelines.object_remove import ObjectRemoval
|
15 |
from internals.pipelines.prompt_modifier import PromptModifier
|
16 |
-
from internals.pipelines.remove_background import
|
17 |
-
RemoveBackgroundV2)
|
18 |
from internals.pipelines.replace_background import ReplaceBackground
|
19 |
from internals.pipelines.safety_checker import SafetyChecker
|
20 |
from internals.pipelines.upscaler import Upscaler
|
21 |
from internals.util.avatar import Avatar
|
22 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
23 |
-
from internals.util.commons import
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
27 |
from internals.util.failure_hander import FailureHandler
|
28 |
from internals.util.lora_style import LoraStyle
|
29 |
from internals.util.model_loader import load_model_from_config
|
@@ -65,7 +67,7 @@ def tile_upscale(task: Task):
|
|
65 |
|
66 |
prompt = get_patched_prompt_tile_upscale(task)
|
67 |
|
68 |
-
controlnet.
|
69 |
|
70 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
71 |
lora_patcher.patch()
|
@@ -98,7 +100,9 @@ def tile_upscale(task: Task):
|
|
98 |
@slack.auto_send_alert
|
99 |
def remove_bg(task: Task):
|
100 |
# remove_background = RemoveBackground()
|
101 |
-
output_image = remove_background_v2.remove(
|
|
|
|
|
102 |
|
103 |
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
104 |
upload_image(output_image, output_key)
|
@@ -173,6 +177,7 @@ def replace_bg(task: Task):
|
|
173 |
extend_object=task.rbg_extend_object(),
|
174 |
product_scale_width=task.get_image_scale(),
|
175 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
|
|
176 |
)
|
177 |
|
178 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
@@ -231,6 +236,7 @@ def model_fn(model_dir):
|
|
231 |
upscaler.load()
|
232 |
inpainter.load()
|
233 |
high_res.load()
|
|
|
234 |
|
235 |
replace_background.load(
|
236 |
upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
|
@@ -242,7 +248,7 @@ def model_fn(model_dir):
|
|
242 |
|
243 |
def load_model_by_task(task: Task):
|
244 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
245 |
-
controlnet.
|
246 |
|
247 |
safety_checker.apply(controlnet)
|
248 |
|
|
|
13 |
from internals.pipelines.inpainter import InPainter
|
14 |
from internals.pipelines.object_remove import ObjectRemoval
|
15 |
from internals.pipelines.prompt_modifier import PromptModifier
|
16 |
+
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
|
|
|
17 |
from internals.pipelines.replace_background import ReplaceBackground
|
18 |
from internals.pipelines.safety_checker import SafetyChecker
|
19 |
from internals.pipelines.upscaler import Upscaler
|
20 |
from internals.util.avatar import Avatar
|
21 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
22 |
+
from internals.util.commons import construct_default_s3_url, upload_image, upload_images
|
23 |
+
from internals.util.config import (
|
24 |
+
num_return_sequences,
|
25 |
+
set_configs_from_task,
|
26 |
+
set_model_config,
|
27 |
+
set_root_dir,
|
28 |
+
)
|
29 |
from internals.util.failure_hander import FailureHandler
|
30 |
from internals.util.lora_style import LoraStyle
|
31 |
from internals.util.model_loader import load_model_from_config
|
|
|
67 |
|
68 |
prompt = get_patched_prompt_tile_upscale(task)
|
69 |
|
70 |
+
controlnet.load_model("tile_upscaler")
|
71 |
|
72 |
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
73 |
lora_patcher.patch()
|
|
|
100 |
@slack.auto_send_alert
|
101 |
def remove_bg(task: Task):
|
102 |
# remove_background = RemoveBackground()
|
103 |
+
output_image = remove_background_v2.remove(
|
104 |
+
task.get_imageUrl(), model_type=task.get_modelType()
|
105 |
+
)
|
106 |
|
107 |
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
108 |
upload_image(output_image, output_key)
|
|
|
177 |
extend_object=task.rbg_extend_object(),
|
178 |
product_scale_width=task.get_image_scale(),
|
179 |
conditioning_scale=task.rbg_controlnet_conditioning_scale(),
|
180 |
+
model_type=task.get_modelType(),
|
181 |
)
|
182 |
|
183 |
generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
|
|
|
236 |
upscaler.load()
|
237 |
inpainter.load()
|
238 |
high_res.load()
|
239 |
+
controlnet.init(high_res)
|
240 |
|
241 |
replace_background.load(
|
242 |
upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
|
|
|
248 |
|
249 |
def load_model_by_task(task: Task):
|
250 |
if task.get_type() == TaskType.TILE_UPSCALE:
|
251 |
+
controlnet.load_model("tile_upscaler")
|
252 |
|
253 |
safety_checker.apply(controlnet)
|
254 |
|
internals/data/dataAccessor.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import traceback
|
2 |
from typing import Dict, List, Optional
|
3 |
|
|
|
4 |
import requests
|
5 |
from pydash import includes
|
6 |
|
@@ -9,6 +10,14 @@ from internals.util.config import api_endpoint, api_headers
|
|
9 |
from internals.util.slack import Slack
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def updateSource(sourceId, userId, state):
|
13 |
print("update source is called")
|
14 |
url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
|
@@ -21,7 +30,8 @@ def updateSource(sourceId, userId, state):
|
|
21 |
data = {"state": state}
|
22 |
|
23 |
try:
|
24 |
-
|
|
|
25 |
print("update source response", response)
|
26 |
except requests.exceptions.Timeout:
|
27 |
print("Request timed out while updating source")
|
@@ -47,7 +57,8 @@ def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
|
|
47 |
data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
|
48 |
|
49 |
try:
|
50 |
-
|
|
|
51 |
# print("save generation response", response)
|
52 |
except requests.exceptions.Timeout:
|
53 |
print("Request timed out while saving image")
|
@@ -61,11 +72,12 @@ def getStyles() -> Optional[Dict]:
|
|
61 |
url = api_endpoint() + "/autodraft-crecoai/style"
|
62 |
print(url)
|
63 |
try:
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
69 |
return response.json()
|
70 |
except requests.exceptions.Timeout:
|
71 |
print("Request timed out while fetching styles")
|
@@ -78,9 +90,10 @@ def getStyles() -> Optional[Dict]:
|
|
78 |
def getCharacters(model_id: str) -> Optional[List]:
|
79 |
url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
|
80 |
try:
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
84 |
return response
|
85 |
except requests.exceptions.Timeout:
|
86 |
print("Request timed out while fetching characters")
|
@@ -89,6 +102,10 @@ def getCharacters(model_id: str) -> Optional[List]:
|
|
89 |
return None
|
90 |
|
91 |
|
|
|
|
|
|
|
|
|
92 |
def update_db(func):
|
93 |
def caller(*args, **kwargs):
|
94 |
if type(args[0]) is not Task:
|
|
|
1 |
import traceback
|
2 |
from typing import Dict, List, Optional
|
3 |
|
4 |
+
from requests.adapters import Retry, HTTPAdapter
|
5 |
import requests
|
6 |
from pydash import includes
|
7 |
|
|
|
10 |
from internals.util.slack import Slack
|
11 |
|
12 |
|
13 |
+
class RetryRequest:
|
14 |
+
def __new__(cls):
|
15 |
+
obj = Retry(total=5, backoff_factor=2, status_forcelist=[500, 502, 503, 504])
|
16 |
+
session = requests.Session()
|
17 |
+
session.mount("https://", HTTPAdapter(max_retries=obj))
|
18 |
+
return session
|
19 |
+
|
20 |
+
|
21 |
def updateSource(sourceId, userId, state):
|
22 |
print("update source is called")
|
23 |
url = api_endpoint() + f"/autodraft-crecoai/source/{sourceId}"
|
|
|
30 |
data = {"state": state}
|
31 |
|
32 |
try:
|
33 |
+
with RetryRequest() as session:
|
34 |
+
response = session.patch(url, headers=headers, json=data, timeout=10)
|
35 |
print("update source response", response)
|
36 |
except requests.exceptions.Timeout:
|
37 |
print("Request timed out while updating source")
|
|
|
57 |
data = {"state": "ACTIVE", "has_nsfw": has_nsfw}
|
58 |
|
59 |
try:
|
60 |
+
with RetryRequest() as session:
|
61 |
+
session.patch(url, headers=headers, json=data)
|
62 |
# print("save generation response", response)
|
63 |
except requests.exceptions.Timeout:
|
64 |
print("Request timed out while saving image")
|
|
|
72 |
url = api_endpoint() + "/autodraft-crecoai/style"
|
73 |
print(url)
|
74 |
try:
|
75 |
+
with RetryRequest() as session:
|
76 |
+
response = session.get(
|
77 |
+
url,
|
78 |
+
timeout=10,
|
79 |
+
headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
|
80 |
+
)
|
81 |
return response.json()
|
82 |
except requests.exceptions.Timeout:
|
83 |
print("Request timed out while fetching styles")
|
|
|
90 |
def getCharacters(model_id: str) -> Optional[List]:
|
91 |
url = api_endpoint() + "/autodraft-crecoai/model/{}".format(model_id)
|
92 |
try:
|
93 |
+
with RetryRequest() as session:
|
94 |
+
response = session.get(url, timeout=10, headers=api_headers())
|
95 |
+
response = response.json()
|
96 |
+
response = response["data"]["characters"]
|
97 |
return response
|
98 |
except requests.exceptions.Timeout:
|
99 |
print("Request timed out while fetching characters")
|
|
|
102 |
return None
|
103 |
|
104 |
|
105 |
+
def update_db_source_failed(sourceId, userId):
|
106 |
+
updateSource(sourceId, userId, "FAILED")
|
107 |
+
|
108 |
+
|
109 |
def update_db(func):
|
110 |
def caller(*args, **kwargs):
|
111 |
if type(args[0]) is not Task:
|
internals/pipelines/commons.py
CHANGED
@@ -2,12 +2,16 @@ from dataclasses import dataclass
|
|
2 |
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
|
4 |
import torch
|
5 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
6 |
|
7 |
from internals.data.result import Result
|
8 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
9 |
from internals.util.commons import disable_safety_checker, download_image
|
10 |
-
from internals.util.config import get_hf_token, num_return_sequences
|
11 |
|
12 |
|
13 |
class AbstractPipeline:
|
@@ -27,9 +31,17 @@ class Text2Img(AbstractPipeline):
|
|
27 |
prompt_right: List[str] = None
|
28 |
|
29 |
def load(self, model_dir: str):
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
self.__patch()
|
34 |
|
35 |
def is_loaded(self):
|
@@ -38,10 +50,16 @@ class Text2Img(AbstractPipeline):
|
|
38 |
return False
|
39 |
|
40 |
def create(self, pipeline: AbstractPipeline):
|
41 |
-
|
|
|
|
|
|
|
42 |
self.__patch()
|
43 |
|
44 |
def __patch(self):
|
|
|
|
|
|
|
45 |
self.pipe.enable_xformers_memory_efficient_attention()
|
46 |
|
47 |
@torch.inference_mode()
|
@@ -92,9 +110,19 @@ class Text2Img(AbstractPipeline):
|
|
92 |
# two step pipeline
|
93 |
modified_prompt = params.modified_prompt
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
height=height,
|
99 |
width=width,
|
100 |
num_inference_steps=num_inference_steps,
|
@@ -111,7 +139,7 @@ class Text2Img(AbstractPipeline):
|
|
111 |
callback=callback,
|
112 |
callback_steps=callback_steps,
|
113 |
cross_attention_kwargs=cross_attention_kwargs,
|
114 |
-
|
115 |
)
|
116 |
|
117 |
return Result.from_result(result)
|
@@ -124,22 +152,38 @@ class Img2Img(AbstractPipeline):
|
|
124 |
if self.__loaded:
|
125 |
return
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
self.__patch()
|
131 |
|
132 |
self.__loaded = True
|
133 |
|
134 |
def create(self, pipeline: AbstractPipeline):
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
138 |
self.__patch()
|
139 |
|
140 |
self.__loaded = True
|
141 |
|
142 |
def __patch(self):
|
|
|
|
|
|
|
143 |
self.pipe.enable_xformers_memory_efficient_attention()
|
144 |
|
145 |
@torch.inference_mode()
|
|
|
2 |
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
|
4 |
import torch
|
5 |
+
from diffusers import (
|
6 |
+
StableDiffusionImg2ImgPipeline,
|
7 |
+
StableDiffusionXLPipeline,
|
8 |
+
StableDiffusionXLImg2ImgPipeline,
|
9 |
+
)
|
10 |
|
11 |
from internals.data.result import Result
|
12 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
13 |
from internals.util.commons import disable_safety_checker, download_image
|
14 |
+
from internals.util.config import get_hf_token, num_return_sequences, get_is_sdxl
|
15 |
|
16 |
|
17 |
class AbstractPipeline:
|
|
|
31 |
prompt_right: List[str] = None
|
32 |
|
33 |
def load(self, model_dir: str):
|
34 |
+
if get_is_sdxl():
|
35 |
+
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
36 |
+
model_dir,
|
37 |
+
torch_dtype=torch.float16,
|
38 |
+
use_auth_token=get_hf_token(),
|
39 |
+
use_safetensors=True,
|
40 |
+
).to("cuda")
|
41 |
+
else:
|
42 |
+
self.pipe = two_step_pipeline.from_pretrained(
|
43 |
+
model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
|
44 |
+
).to("cuda")
|
45 |
self.__patch()
|
46 |
|
47 |
def is_loaded(self):
|
|
|
50 |
return False
|
51 |
|
52 |
def create(self, pipeline: AbstractPipeline):
|
53 |
+
if get_is_sdxl():
|
54 |
+
self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
|
55 |
+
else:
|
56 |
+
self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
|
57 |
self.__patch()
|
58 |
|
59 |
def __patch(self):
|
60 |
+
if get_is_sdxl():
|
61 |
+
self.pipe.enable_vae_tiling()
|
62 |
+
self.pipe.enable_vae_slicing()
|
63 |
self.pipe.enable_xformers_memory_efficient_attention()
|
64 |
|
65 |
@torch.inference_mode()
|
|
|
110 |
# two step pipeline
|
111 |
modified_prompt = params.modified_prompt
|
112 |
|
113 |
+
if get_is_sdxl():
|
114 |
+
print("Warning: Two step pipeline is not supported on SDXL")
|
115 |
+
kwargs = {
|
116 |
+
"prompt": modified_prompt,
|
117 |
+
}
|
118 |
+
else:
|
119 |
+
kwargs = {
|
120 |
+
"prompt": prompt,
|
121 |
+
"modified_prompts": modified_prompt,
|
122 |
+
"iteration": iteration,
|
123 |
+
}
|
124 |
+
|
125 |
+
result = self.pipe.__call__(
|
126 |
height=height,
|
127 |
width=width,
|
128 |
num_inference_steps=num_inference_steps,
|
|
|
139 |
callback=callback,
|
140 |
callback_steps=callback_steps,
|
141 |
cross_attention_kwargs=cross_attention_kwargs,
|
142 |
+
**kwargs
|
143 |
)
|
144 |
|
145 |
return Result.from_result(result)
|
|
|
152 |
if self.__loaded:
|
153 |
return
|
154 |
|
155 |
+
if get_is_sdxl():
|
156 |
+
self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
157 |
+
model_dir,
|
158 |
+
torch_dtype=torch.float16,
|
159 |
+
use_auth_token=get_hf_token(),
|
160 |
+
use_safetensors=True,
|
161 |
+
).to("cuda")
|
162 |
+
else:
|
163 |
+
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
164 |
+
model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
|
165 |
+
).to("cuda")
|
166 |
self.__patch()
|
167 |
|
168 |
self.__loaded = True
|
169 |
|
170 |
def create(self, pipeline: AbstractPipeline):
|
171 |
+
if get_is_sdxl():
|
172 |
+
self.pipe = StableDiffusionXLImg2ImgPipeline(**pipeline.pipe.components).to(
|
173 |
+
"cuda"
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
|
177 |
+
"cuda"
|
178 |
+
)
|
179 |
self.__patch()
|
180 |
|
181 |
self.__loaded = True
|
182 |
|
183 |
def __patch(self):
|
184 |
+
if get_is_sdxl():
|
185 |
+
self.pipe.enable_vae_tiling()
|
186 |
+
self.pipe.enable_vae_slicing()
|
187 |
self.pipe.enable_xformers_memory_efficient_attention()
|
188 |
|
189 |
@torch.inference_mode()
|
internals/pipelines/controlnets.py
CHANGED
@@ -1,14 +1,20 @@
|
|
1 |
-
from typing import List, Union
|
2 |
|
3 |
import cv2
|
4 |
import numpy as np
|
|
|
5 |
import torch
|
6 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
7 |
-
from diffusers import (
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
12 |
from PIL import Image
|
13 |
from torch.nn import Linear
|
14 |
from tqdm import gui
|
@@ -18,156 +24,127 @@ import internals.util.image as ImageUtil
|
|
18 |
from external.midas import apply_midas
|
19 |
from internals.data.result import Result
|
20 |
from internals.pipelines.commons import AbstractPipeline
|
21 |
-
from internals.pipelines.tileUpscalePipeline import
|
22 |
-
StableDiffusionControlNetImg2ImgPipeline
|
|
|
23 |
from internals.util.cache import clear_cuda_and_gc
|
24 |
from internals.util.commons import download_image
|
25 |
-
from internals.util.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
class ControlNet(AbstractPipeline):
|
29 |
__current_task_name = ""
|
30 |
__loaded = False
|
31 |
|
32 |
-
|
33 |
-
"Should not be called externally"
|
34 |
-
if self.__loaded:
|
35 |
-
return
|
36 |
|
37 |
-
|
38 |
-
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
controlnet=self.controlnet,
|
44 |
-
torch_dtype=torch.float16,
|
45 |
-
use_auth_token=get_hf_token(),
|
46 |
-
cache_dir=get_hf_cache_dir(),
|
47 |
-
).to("cuda")
|
48 |
-
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
49 |
-
pipe.enable_model_cpu_offload()
|
50 |
-
pipe.enable_xformers_memory_efficient_attention()
|
51 |
-
self.pipe = pipe
|
52 |
-
|
53 |
-
# controlnet pipeline for canny and pose
|
54 |
-
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
|
55 |
-
pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
|
56 |
-
pipe2.enable_xformers_memory_efficient_attention()
|
57 |
-
self.pipe2 = pipe2
|
58 |
-
|
59 |
-
self.__loaded = True
|
60 |
-
|
61 |
-
def load_canny(self):
|
62 |
-
if self.__current_task_name == "canny":
|
63 |
return
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
torch_dtype=torch.float16,
|
67 |
cache_dir=get_hf_cache_dir(),
|
68 |
).to("cuda")
|
69 |
-
self.__current_task_name =
|
70 |
-
self.controlnet =
|
71 |
|
72 |
-
self.
|
73 |
|
74 |
if hasattr(self, "pipe"):
|
75 |
-
self.pipe.controlnet =
|
76 |
if hasattr(self, "pipe2"):
|
77 |
-
self.pipe2.controlnet =
|
78 |
clear_cuda_and_gc()
|
79 |
|
80 |
-
def
|
81 |
-
|
|
|
82 |
return
|
83 |
-
pose = ControlNetModel.from_pretrained(
|
84 |
-
"lllyasviel/control_v11p_sd15_openpose",
|
85 |
-
torch_dtype=torch.float16,
|
86 |
-
cache_dir=get_hf_cache_dir(),
|
87 |
-
).to("cuda")
|
88 |
-
# lineart = ControlNetModel.from_pretrained(
|
89 |
-
# "ControlNet-1-1-preview/control_v11p_sd15_lineart",
|
90 |
-
# torch_dtype=torch.float16,
|
91 |
-
# cache_dir=get_hf_cache_dir(),
|
92 |
-
# ).to("cuda")
|
93 |
-
self.__current_task_name = "pose"
|
94 |
-
self.controlnet = MultiControlNetModel([pose]).to("cuda")
|
95 |
-
|
96 |
-
self.load()
|
97 |
|
98 |
-
if hasattr(self, "
|
99 |
-
self.
|
100 |
-
if hasattr(self, "pipe2"):
|
101 |
-
self.pipe2.controlnet = self.controlnet
|
102 |
-
clear_cuda_and_gc()
|
103 |
-
|
104 |
-
def load_tile_upscaler(self):
|
105 |
-
if self.__current_task_name == "tile_upscaler":
|
106 |
-
return
|
107 |
-
tile_upscaler = ControlNetModel.from_pretrained(
|
108 |
-
"lllyasviel/control_v11f1e_sd15_tile",
|
109 |
-
torch_dtype=torch.float16,
|
110 |
-
cache_dir=get_hf_cache_dir(),
|
111 |
-
).to("cuda")
|
112 |
-
self.__current_task_name = "tile_upscaler"
|
113 |
-
self.controlnet = tile_upscaler
|
114 |
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
|
118 |
-
self.pipe.controlnet = tile_upscaler
|
119 |
-
if hasattr(self, "pipe2"):
|
120 |
-
self.pipe2.controlnet = tile_upscaler
|
121 |
-
clear_cuda_and_gc()
|
122 |
|
123 |
-
def
|
|
|
|
|
|
|
|
|
124 |
if self.__current_task_name == "scribble":
|
125 |
-
return
|
126 |
-
scribble = ControlNetModel.from_pretrained(
|
127 |
-
"lllyasviel/control_v11p_sd15_scribble",
|
128 |
-
torch_dtype=torch.float16,
|
129 |
-
cache_dir=get_hf_cache_dir(),
|
130 |
-
).to("cuda")
|
131 |
-
self.__current_task_name = "scribble"
|
132 |
-
self.controlnet = scribble
|
133 |
-
|
134 |
-
self.load()
|
135 |
-
|
136 |
-
if hasattr(self, "pipe"):
|
137 |
-
self.pipe.controlnet = scribble
|
138 |
-
if hasattr(self, "pipe2"):
|
139 |
-
self.pipe2.controlnet = scribble
|
140 |
-
clear_cuda_and_gc()
|
141 |
-
|
142 |
-
def load_linearart(self):
|
143 |
if self.__current_task_name == "linearart":
|
144 |
-
return
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
cache_dir=get_hf_cache_dir(),
|
149 |
-
).to("cuda")
|
150 |
-
self.__current_task_name = "linearart"
|
151 |
-
self.controlnet = linearart
|
152 |
-
|
153 |
-
self.load()
|
154 |
-
|
155 |
-
if hasattr(self, "pipe"):
|
156 |
-
self.pipe.controlnet = linearart
|
157 |
-
if hasattr(self, "pipe2"):
|
158 |
-
self.pipe2.controlnet = linearart
|
159 |
-
clear_cuda_and_gc()
|
160 |
-
|
161 |
-
def cleanup(self):
|
162 |
-
if hasattr(self, "pipe"):
|
163 |
-
self.pipe.controlnet = None
|
164 |
-
if hasattr(self, "pipe2"):
|
165 |
-
self.pipe2.controlnet = None
|
166 |
-
self.controlnet = None
|
167 |
-
del self.controlnet
|
168 |
-
self.__current_task_name = ""
|
169 |
-
|
170 |
-
clear_cuda_and_gc()
|
171 |
|
172 |
@torch.inference_mode()
|
173 |
def process_canny(
|
@@ -228,7 +205,6 @@ class ControlNet(AbstractPipeline):
|
|
228 |
guidance_scale=guidance_scale,
|
229 |
height=height,
|
230 |
width=width,
|
231 |
-
controlnet_conditioning_scale=[1.0],
|
232 |
)
|
233 |
return Result.from_result(result)
|
234 |
|
@@ -333,6 +309,17 @@ class ControlNet(AbstractPipeline):
|
|
333 |
)
|
334 |
return Result.from_result(result)
|
335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
def detect_pose(self, imageUrl: str) -> Image.Image:
|
337 |
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
338 |
image = download_image(imageUrl)
|
@@ -381,3 +368,18 @@ class ControlNet(AbstractPipeline):
|
|
381 |
W = int(round(W / 64.0)) * 64
|
382 |
img = input_image.resize((W, H), resample=Image.LANCZOS)
|
383 |
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Literal, Union
|
2 |
|
3 |
import cv2
|
4 |
import numpy as np
|
5 |
+
from pydash import has
|
6 |
import torch
|
7 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
8 |
+
from diffusers import (
|
9 |
+
ControlNetModel,
|
10 |
+
DiffusionPipeline,
|
11 |
+
StableDiffusionControlNetPipeline,
|
12 |
+
UniPCMultistepScheduler,
|
13 |
+
StableDiffusionXLControlNetPipeline,
|
14 |
+
)
|
15 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
|
16 |
+
MultiControlNetModel,
|
17 |
+
)
|
18 |
from PIL import Image
|
19 |
from torch.nn import Linear
|
20 |
from tqdm import gui
|
|
|
24 |
from external.midas import apply_midas
|
25 |
from internals.data.result import Result
|
26 |
from internals.pipelines.commons import AbstractPipeline
|
27 |
+
from internals.pipelines.tileUpscalePipeline import (
|
28 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
29 |
+
)
|
30 |
from internals.util.cache import clear_cuda_and_gc
|
31 |
from internals.util.commons import download_image
|
32 |
+
from internals.util.config import (
|
33 |
+
get_hf_cache_dir,
|
34 |
+
get_hf_token,
|
35 |
+
get_model_dir,
|
36 |
+
get_is_sdxl,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]
|
41 |
|
42 |
|
43 |
class ControlNet(AbstractPipeline):
|
44 |
__current_task_name = ""
|
45 |
__loaded = False
|
46 |
|
47 |
+
__pipeline: AbstractPipeline
|
|
|
|
|
|
|
48 |
|
49 |
+
def init(self, pipeline: AbstractPipeline):
|
50 |
+
self.__pipeline = pipeline
|
51 |
|
52 |
+
def load_model(self, task_name: CONTROLNET_TYPES):
|
53 |
+
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
|
54 |
+
if self.__current_task_name == task_name:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
return
|
56 |
+
model = config[task_name]
|
57 |
+
if not model:
|
58 |
+
raise Exception(f"ControlNet is not supported for {task_name}")
|
59 |
+
while model in list(config.keys()):
|
60 |
+
task_name = config[model] # pyright: ignore
|
61 |
+
model = config[task_name]
|
62 |
+
|
63 |
+
controlnet = ControlNetModel.from_pretrained(
|
64 |
+
model,
|
65 |
torch_dtype=torch.float16,
|
66 |
cache_dir=get_hf_cache_dir(),
|
67 |
).to("cuda")
|
68 |
+
self.__current_task_name = task_name
|
69 |
+
self.controlnet = controlnet
|
70 |
|
71 |
+
self.__load()
|
72 |
|
73 |
if hasattr(self, "pipe"):
|
74 |
+
self.pipe.controlnet = controlnet
|
75 |
if hasattr(self, "pipe2"):
|
76 |
+
self.pipe2.controlnet = controlnet
|
77 |
clear_cuda_and_gc()
|
78 |
|
79 |
+
def __load(self):
|
80 |
+
"Should not be called externally"
|
81 |
+
if self.__loaded:
|
82 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
+
if not hasattr(self, "controlnet"):
|
85 |
+
self.load_model("pose")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
+
# controlnet pipeline for tile upscaler
|
88 |
+
if get_is_sdxl():
|
89 |
+
print("Warning: Tile upscale is not supported on SDXL")
|
90 |
+
|
91 |
+
if self.__pipeline:
|
92 |
+
pipe = StableDiffusionXLControlNetPipeline(
|
93 |
+
controlnet=self.controlnet, **self.__pipeline.pipe.components
|
94 |
+
).to("cuda")
|
95 |
+
else:
|
96 |
+
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
97 |
+
get_model_dir(),
|
98 |
+
controlnet=self.controlnet,
|
99 |
+
torch_dtype=torch.float16,
|
100 |
+
use_auth_token=get_hf_token(),
|
101 |
+
cache_dir=get_hf_cache_dir(),
|
102 |
+
use_safetensors=True,
|
103 |
+
).to("cuda")
|
104 |
+
pipe.enable_vae_tiling()
|
105 |
+
pipe.enable_vae_slicing()
|
106 |
+
pipe.enable_xformers_memory_efficient_attention()
|
107 |
+
self.pipe2 = pipe
|
108 |
+
else:
|
109 |
+
if hasattr(self, "__pipeline"):
|
110 |
+
pipe = StableDiffusionControlNetImg2ImgPipeline(
|
111 |
+
controlnet=self.controlnet, **self.__pipeline.pipe.components
|
112 |
+
).to("cuda")
|
113 |
+
else:
|
114 |
+
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
115 |
+
get_model_dir(),
|
116 |
+
controlnet=self.controlnet,
|
117 |
+
torch_dtype=torch.float16,
|
118 |
+
use_auth_token=get_hf_token(),
|
119 |
+
cache_dir=get_hf_cache_dir(),
|
120 |
+
).to("cuda")
|
121 |
+
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
122 |
+
pipe.enable_model_cpu_offload()
|
123 |
+
pipe.enable_xformers_memory_efficient_attention()
|
124 |
+
self.pipe = pipe
|
125 |
+
|
126 |
+
# controlnet pipeline for canny and pose
|
127 |
+
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
|
128 |
+
pipe2.scheduler = UniPCMultistepScheduler.from_config(
|
129 |
+
pipe2.scheduler.config
|
130 |
+
)
|
131 |
+
pipe2.enable_xformers_memory_efficient_attention()
|
132 |
+
self.pipe2 = pipe2
|
133 |
|
134 |
+
self.__loaded = True
|
|
|
|
|
|
|
|
|
135 |
|
136 |
+
def process(self, **kwargs):
|
137 |
+
if self.__current_task_name == "pose":
|
138 |
+
return self.process_pose(**kwargs)
|
139 |
+
if self.__current_task_name == "canny":
|
140 |
+
return self.process_canny(**kwargs)
|
141 |
if self.__current_task_name == "scribble":
|
142 |
+
return self.process_scribble(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
if self.__current_task_name == "linearart":
|
144 |
+
return self.process_linearart(**kwargs)
|
145 |
+
if self.__current_task_name == "tile_upscaler":
|
146 |
+
return self.process_tile_upscaler(**kwargs)
|
147 |
+
raise Exception("ControlNet is not loaded with any model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
@torch.inference_mode()
|
150 |
def process_canny(
|
|
|
205 |
guidance_scale=guidance_scale,
|
206 |
height=height,
|
207 |
width=width,
|
|
|
208 |
)
|
209 |
return Result.from_result(result)
|
210 |
|
|
|
309 |
)
|
310 |
return Result.from_result(result)
|
311 |
|
312 |
+
def cleanup(self):
|
313 |
+
if hasattr(self, "pipe") and hasattr(self.pipe, "controlnet"):
|
314 |
+
del self.pipe.controlnet
|
315 |
+
if hasattr(self, "pipe2") and hasattr(self.pipe2, "controlnet"):
|
316 |
+
del self.pipe2.controlnet
|
317 |
+
if hasattr(self, "controlnet"):
|
318 |
+
del self.controlnet
|
319 |
+
self.__current_task_name = ""
|
320 |
+
|
321 |
+
clear_cuda_and_gc()
|
322 |
+
|
323 |
def detect_pose(self, imageUrl: str) -> Image.Image:
|
324 |
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
325 |
image = download_image(imageUrl)
|
|
|
368 |
W = int(round(W / 64.0)) * 64
|
369 |
img = input_image.resize((W, H), resample=Image.LANCZOS)
|
370 |
return img
|
371 |
+
|
372 |
+
__model_normal = {
|
373 |
+
"pose": "lllyasviel/control_v11p_sd15_openpose",
|
374 |
+
"canny": "lllyasviel/control_v11p_sd15_canny",
|
375 |
+
"linearart": "lllyasviel/control_v11p_sd15_lineart",
|
376 |
+
"scribble": "lllyasviel/control_v11p_sd15_scribble",
|
377 |
+
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
|
378 |
+
}
|
379 |
+
__model_sdxl = {
|
380 |
+
"pose": "thibaud/controlnet-openpose-sdxl-1.0",
|
381 |
+
"canny": "diffusers/controlnet-canny-sdxl-1.0",
|
382 |
+
"linearart": "canny",
|
383 |
+
"scribble": "canny",
|
384 |
+
"tile_upscaler": None,
|
385 |
+
}
|
internals/pipelines/high_res.py
CHANGED
@@ -42,7 +42,7 @@ class HighRes(AbstractPipeline):
|
|
42 |
|
43 |
@staticmethod
|
44 |
def get_intermediate_dimension(target_width: int, target_height: int):
|
45 |
-
def_size =
|
46 |
|
47 |
desired_pixel_count = def_size * def_size
|
48 |
actual_pixel_count = target_width * target_height
|
|
|
42 |
|
43 |
@staticmethod
|
44 |
def get_intermediate_dimension(target_width: int, target_height: int):
|
45 |
+
def_size = 1024
|
46 |
|
47 |
desired_pixel_count = def_size * def_size
|
48 |
actual_pixel_count = target_width * target_height
|
internals/pipelines/inpainter.py
CHANGED
@@ -1,38 +1,74 @@
|
|
1 |
from typing import List, Union
|
2 |
|
3 |
import torch
|
4 |
-
from diffusers import StableDiffusionInpaintPipeline
|
5 |
|
6 |
from internals.pipelines.commons import AbstractPipeline
|
7 |
from internals.util.commons import disable_safety_checker, download_image
|
8 |
-
from internals.util.config import (
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
class InPainter(AbstractPipeline):
|
13 |
__loaded = False
|
14 |
|
|
|
|
|
|
|
15 |
def load(self):
|
16 |
if self.__loaded:
|
17 |
return
|
18 |
|
19 |
-
self
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
disable_safety_checker(self.pipe)
|
27 |
|
|
|
|
|
28 |
self.__loaded = True
|
29 |
|
30 |
def create(self, pipeline: AbstractPipeline):
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
34 |
disable_safety_checker(self.pipe)
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
@torch.inference_mode()
|
37 |
def process(
|
38 |
self,
|
|
|
1 |
from typing import List, Union
|
2 |
|
3 |
import torch
|
4 |
+
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
|
5 |
|
6 |
from internals.pipelines.commons import AbstractPipeline
|
7 |
from internals.util.commons import disable_safety_checker, download_image
|
8 |
+
from internals.util.config import (
|
9 |
+
get_hf_cache_dir,
|
10 |
+
get_hf_token,
|
11 |
+
get_is_sdxl,
|
12 |
+
get_inpaint_model_path,
|
13 |
+
get_model_dir,
|
14 |
+
)
|
15 |
|
16 |
|
17 |
class InPainter(AbstractPipeline):
|
18 |
__loaded = False
|
19 |
|
20 |
+
def init(self, pipeline: AbstractPipeline):
|
21 |
+
self.__base = pipeline
|
22 |
+
|
23 |
def load(self):
|
24 |
if self.__loaded:
|
25 |
return
|
26 |
|
27 |
+
if hasattr(self, "__base") and get_inpaint_model_path() == get_model_dir():
|
28 |
+
self.create(self.__base)
|
29 |
+
self.__loaded = True
|
30 |
+
return
|
31 |
+
|
32 |
+
if get_is_sdxl():
|
33 |
+
self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
|
34 |
+
get_inpaint_model_path(),
|
35 |
+
torch_dtype=torch.float16,
|
36 |
+
cache_dir=get_hf_cache_dir(),
|
37 |
+
use_auth_token=get_hf_token(),
|
38 |
+
).to("cuda")
|
39 |
+
else:
|
40 |
+
self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
41 |
+
get_inpaint_model_path(),
|
42 |
+
torch_dtype=torch.float16,
|
43 |
+
cache_dir=get_hf_cache_dir(),
|
44 |
+
use_auth_token=get_hf_token(),
|
45 |
+
).to("cuda")
|
46 |
|
47 |
disable_safety_checker(self.pipe)
|
48 |
|
49 |
+
self.__patch()
|
50 |
+
|
51 |
self.__loaded = True
|
52 |
|
53 |
def create(self, pipeline: AbstractPipeline):
|
54 |
+
if get_is_sdxl():
|
55 |
+
self.pipe = StableDiffusionXLInpaintPipeline(**pipeline.pipe.components).to(
|
56 |
+
"cuda"
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
|
60 |
+
"cuda"
|
61 |
+
)
|
62 |
disable_safety_checker(self.pipe)
|
63 |
|
64 |
+
self.__patch()
|
65 |
+
|
66 |
+
def __patch(self):
|
67 |
+
if get_is_sdxl():
|
68 |
+
self.pipe.enable_vae_tiling()
|
69 |
+
self.pipe.enable_vae_slicing()
|
70 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
71 |
+
|
72 |
@torch.inference_mode()
|
73 |
def process(
|
74 |
self,
|
internals/pipelines/remove_background.py
CHANGED
@@ -1,15 +1,20 @@
|
|
1 |
import io
|
2 |
from pathlib import Path
|
3 |
from typing import Union
|
|
|
|
|
4 |
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from PIL import Image
|
8 |
from rembg import remove
|
|
|
9 |
|
10 |
import internals.util.image as ImageUtil
|
11 |
from carvekit.api.high import HiInterface
|
12 |
from internals.util.commons import download_image, read_url
|
|
|
|
|
13 |
|
14 |
|
15 |
class RemoveBackground:
|
@@ -23,6 +28,11 @@ class RemoveBackground:
|
|
23 |
|
24 |
class RemoveBackgroundV2:
|
25 |
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
26 |
self.interface = HiInterface(
|
27 |
object_type="object", # Can be "object" or "hairs-like".
|
28 |
batch_size_seg=5,
|
@@ -36,16 +46,51 @@ class RemoveBackgroundV2:
|
|
36 |
fp16=False,
|
37 |
)
|
38 |
|
39 |
-
def remove(
|
40 |
-
|
|
|
41 |
if type(image) is str:
|
42 |
image = download_image(image)
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
1 |
import io
|
2 |
from pathlib import Path
|
3 |
from typing import Union
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
|
7 |
import torch
|
8 |
import torch.nn.functional as F
|
9 |
from PIL import Image
|
10 |
from rembg import remove
|
11 |
+
from internals.data.task import ModelType
|
12 |
|
13 |
import internals.util.image as ImageUtil
|
14 |
from carvekit.api.high import HiInterface
|
15 |
from internals.util.commons import download_image, read_url
|
16 |
+
import onnxruntime as rt
|
17 |
+
import huggingface_hub
|
18 |
|
19 |
|
20 |
class RemoveBackground:
|
|
|
28 |
|
29 |
class RemoveBackgroundV2:
|
30 |
def __init__(self):
|
31 |
+
model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
|
32 |
+
self.anime_rembg = rt.InferenceSession(
|
33 |
+
model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
34 |
+
)
|
35 |
+
|
36 |
self.interface = HiInterface(
|
37 |
object_type="object", # Can be "object" or "hairs-like".
|
38 |
batch_size_seg=5,
|
|
|
46 |
fp16=False,
|
47 |
)
|
48 |
|
49 |
+
def remove(
|
50 |
+
self, image: Union[str, Image.Image], model_type: ModelType = ModelType.REAL
|
51 |
+
) -> Image.Image:
|
52 |
if type(image) is str:
|
53 |
image = download_image(image)
|
54 |
|
55 |
+
if model_type == ModelType.ANIME or model_type == ModelType.COMIC:
|
56 |
+
print("Using Anime Background remover")
|
57 |
+
_, img = self.__rmbg_fn(np.array(image))
|
58 |
+
|
59 |
+
return Image.fromarray(img)
|
60 |
+
else:
|
61 |
+
print("Using Real Background remover")
|
62 |
+
img_path = Path.home() / ".cache" / "rm_bg.png"
|
63 |
+
|
64 |
+
w, h = image.size
|
65 |
+
if max(w, h) > 1536:
|
66 |
+
image = ImageUtil.resize_image(image, dimension=1024)
|
67 |
+
|
68 |
+
image.save(img_path)
|
69 |
+
images_without_background = self.interface([img_path])
|
70 |
+
out = images_without_background[0]
|
71 |
+
return out
|
72 |
+
|
73 |
+
def __get_mask(self, img, s=1024):
|
74 |
+
img = (img / 255).astype(np.float32)
|
75 |
+
h, w = h0, w0 = img.shape[:-1]
|
76 |
+
h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
|
77 |
+
ph, pw = s - h, s - w
|
78 |
+
img_input = np.zeros([s, s, 3], dtype=np.float32)
|
79 |
+
img_input[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = cv2.resize(
|
80 |
+
img, (w, h)
|
81 |
+
)
|
82 |
+
img_input = np.transpose(img_input, (2, 0, 1))
|
83 |
+
img_input = img_input[np.newaxis, :]
|
84 |
+
mask = self.anime_rembg.run(None, {"img": img_input})[0][0]
|
85 |
+
mask = np.transpose(mask, (1, 2, 0))
|
86 |
+
mask = mask[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
|
87 |
+
mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
|
88 |
+
return mask
|
89 |
|
90 |
+
def __rmbg_fn(self, img):
|
91 |
+
mask = self.__get_mask(img)
|
92 |
+
img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
|
93 |
+
mask = (mask * 255).astype(np.uint8)
|
94 |
+
img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
|
95 |
+
mask = mask.repeat(3, axis=2)
|
96 |
+
return mask, img
|
internals/pipelines/replace_background.py
CHANGED
@@ -3,10 +3,14 @@ from typing import List, Optional, Union
|
|
3 |
|
4 |
import torch
|
5 |
from cv2 import inpaint
|
6 |
-
from diffusers import (
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
9 |
from PIL import Image, ImageFilter, ImageOps
|
|
|
10 |
|
11 |
import internals.util.image as ImageUtil
|
12 |
from internals.data.result import Result
|
@@ -17,8 +21,12 @@ from internals.pipelines.inpainter import InPainter
|
|
17 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
18 |
from internals.pipelines.upscaler import Upscaler
|
19 |
from internals.util.commons import download_image
|
20 |
-
from internals.util.config import (
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
class ReplaceBackground(AbstractPipeline):
|
@@ -52,7 +60,8 @@ class ReplaceBackground(AbstractPipeline):
|
|
52 |
cache_dir=get_hf_cache_dir(),
|
53 |
use_auth_token=get_hf_token(),
|
54 |
)
|
55 |
-
pipe.
|
|
|
56 |
pipe.to("cuda")
|
57 |
|
58 |
self.pipe = pipe
|
@@ -87,6 +96,7 @@ class ReplaceBackground(AbstractPipeline):
|
|
87 |
seed: int,
|
88 |
steps: int,
|
89 |
apply_high_res: bool = False,
|
|
|
90 |
):
|
91 |
# image = Image.open("original.png")
|
92 |
if type(image) is str:
|
@@ -98,7 +108,7 @@ class ReplaceBackground(AbstractPipeline):
|
|
98 |
image = image.convert("RGB")
|
99 |
if max(image.size) > 1024:
|
100 |
image = ImageUtil.resize_image(image, dimension=1024)
|
101 |
-
image = self.remove_background.remove(image)
|
102 |
|
103 |
width = int(width)
|
104 |
height = int(height)
|
|
|
3 |
|
4 |
import torch
|
5 |
from cv2 import inpaint
|
6 |
+
from diffusers import (
|
7 |
+
ControlNetModel,
|
8 |
+
StableDiffusionControlNetInpaintPipeline,
|
9 |
+
StableDiffusionInpaintPipeline,
|
10 |
+
UniPCMultistepScheduler,
|
11 |
+
)
|
12 |
from PIL import Image, ImageFilter, ImageOps
|
13 |
+
from internals.data.task import ModelType
|
14 |
|
15 |
import internals.util.image as ImageUtil
|
16 |
from internals.data.result import Result
|
|
|
21 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
22 |
from internals.pipelines.upscaler import Upscaler
|
23 |
from internals.util.commons import download_image
|
24 |
+
from internals.util.config import (
|
25 |
+
get_hf_cache_dir,
|
26 |
+
get_hf_token,
|
27 |
+
get_inpaint_model_path,
|
28 |
+
get_model_dir,
|
29 |
+
)
|
30 |
|
31 |
|
32 |
class ReplaceBackground(AbstractPipeline):
|
|
|
60 |
cache_dir=get_hf_cache_dir(),
|
61 |
use_auth_token=get_hf_token(),
|
62 |
)
|
63 |
+
pipe.enable_xformers_memory_efficient_attention()
|
64 |
+
pipe.enable_vae_slicing()
|
65 |
pipe.to("cuda")
|
66 |
|
67 |
self.pipe = pipe
|
|
|
96 |
seed: int,
|
97 |
steps: int,
|
98 |
apply_high_res: bool = False,
|
99 |
+
model_type: ModelType = ModelType.REAL,
|
100 |
):
|
101 |
# image = Image.open("original.png")
|
102 |
if type(image) is str:
|
|
|
108 |
image = image.convert("RGB")
|
109 |
if max(image.size) > 1024:
|
110 |
image = ImageUtil.resize_image(image, dimension=1024)
|
111 |
+
image = self.remove_background.remove(image, model_type=model_type)
|
112 |
|
113 |
width = int(width)
|
114 |
height = int(height)
|
internals/pipelines/twoStepPipeline.py
CHANGED
@@ -12,7 +12,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|
12 |
|
13 |
class two_step_pipeline(StableDiffusionPipeline):
|
14 |
@torch.no_grad()
|
15 |
-
def
|
16 |
self,
|
17 |
prompt: Union[str, List[str]] = None,
|
18 |
modified_prompts: Union[str, List[str]] = None,
|
|
|
12 |
|
13 |
class two_step_pipeline(StableDiffusionPipeline):
|
14 |
@torch.no_grad()
|
15 |
+
def __call__(
|
16 |
self,
|
17 |
prompt: Union[str, List[str]] = None,
|
18 |
modified_prompts: Union[str, List[str]] = None,
|
internals/util/cache.py
CHANGED
@@ -1,15 +1,25 @@
|
|
1 |
import gc
|
2 |
-
|
|
|
3 |
import torch
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
def clear_cuda_and_gc():
|
7 |
-
|
|
|
8 |
clear_gc()
|
|
|
|
|
9 |
|
10 |
|
11 |
def clear_cuda():
|
12 |
-
torch.
|
|
|
13 |
|
14 |
|
15 |
def clear_gc():
|
|
|
1 |
import gc
|
2 |
+
import os
|
3 |
+
import psutil
|
4 |
import torch
|
5 |
|
6 |
|
7 |
+
def print_memory_usage():
|
8 |
+
process = psutil.Process(os.getpid())
|
9 |
+
print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB")
|
10 |
+
|
11 |
+
|
12 |
def clear_cuda_and_gc():
|
13 |
+
print_memory_usage()
|
14 |
+
print("Clearing cuda and gc")
|
15 |
clear_gc()
|
16 |
+
clear_cuda()
|
17 |
+
print_memory_usage()
|
18 |
|
19 |
|
20 |
def clear_cuda():
|
21 |
+
with torch.no_grad():
|
22 |
+
torch.cuda.empty_cache()
|
23 |
|
24 |
|
25 |
def clear_gc():
|
internals/util/commons.py
CHANGED
@@ -150,9 +150,9 @@ def upload_image(image: Union[Image.Image, BytesIO], out_path):
|
|
150 |
return image_url
|
151 |
|
152 |
|
153 |
-
def download_image(url) -> Image.Image:
|
154 |
response = requests.get(url)
|
155 |
-
return Image.open(BytesIO(response.content)).convert(
|
156 |
|
157 |
|
158 |
def download_file(url, out_path: Path):
|
|
|
150 |
return image_url
|
151 |
|
152 |
|
153 |
+
def download_image(url, mode="RGB") -> Image.Image:
|
154 |
response = requests.get(url)
|
155 |
+
return Image.open(BytesIO(response.content)).convert(mode)
|
156 |
|
157 |
|
158 |
def download_file(url, out_path: Path):
|
internals/util/config.py
CHANGED
@@ -61,6 +61,11 @@ def get_inpaint_model_path():
|
|
61 |
return model_config.base_inpaint_model_path # pyright: ignore
|
62 |
|
63 |
|
|
|
|
|
|
|
|
|
|
|
64 |
def get_root_dir():
|
65 |
global root_dir
|
66 |
return root_dir
|
|
|
61 |
return model_config.base_inpaint_model_path # pyright: ignore
|
62 |
|
63 |
|
64 |
+
def get_is_sdxl():
|
65 |
+
global model_config
|
66 |
+
return model_config.is_sdxl # pyright: ignore
|
67 |
+
|
68 |
+
|
69 |
def get_root_dir():
|
70 |
global root_dir
|
71 |
return root_dir
|
internals/util/lora_style.py
CHANGED
@@ -10,6 +10,7 @@ from lora_diffusion import patch_pipe, tune_lora_scale
|
|
10 |
from pydash import chain
|
11 |
|
12 |
from internals.data.dataAccessor import getStyles
|
|
|
13 |
from internals.util.commons import download_file
|
14 |
|
15 |
|
@@ -112,6 +113,10 @@ class LoraStyle:
|
|
112 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
113 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
114 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
|
|
|
|
|
|
|
|
115 |
if key in self.__styles:
|
116 |
style = self.__styles[key]
|
117 |
if style["type"] == "diffuser":
|
|
|
10 |
from pydash import chain
|
11 |
|
12 |
from internals.data.dataAccessor import getStyles
|
13 |
+
from internals.util.config import get_is_sdxl
|
14 |
from internals.util.commons import download_file
|
15 |
|
16 |
|
|
|
113 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
114 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
115 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
116 |
+
if get_is_sdxl():
|
117 |
+
print("Warning: Lora is not supported on SDXL")
|
118 |
+
return self.EmptyLoraPatcher(pipe)
|
119 |
+
|
120 |
if key in self.__styles:
|
121 |
style = self.__styles[key]
|
122 |
if style["type"] == "diffuser":
|
internals/util/model_loader.py
CHANGED
@@ -14,6 +14,7 @@ from tqdm import tqdm
|
|
14 |
class ModelConfig:
|
15 |
base_model_path: str
|
16 |
base_inpaint_model_path: str
|
|
|
17 |
|
18 |
|
19 |
def load_model_from_config(path):
|
@@ -23,9 +24,11 @@ def load_model_from_config(path):
|
|
23 |
config = json.loads(f.read())
|
24 |
model_path = config.get("model_path", path)
|
25 |
inpaint_model_path = config.get("inpaint_model_path", path)
|
|
|
26 |
|
27 |
m_config.base_model_path = model_path
|
28 |
m_config.base_inpaint_model_path = inpaint_model_path
|
|
|
29 |
|
30 |
#
|
31 |
# if config.get("model_type") == "huggingface":
|
|
|
14 |
class ModelConfig:
|
15 |
base_model_path: str
|
16 |
base_inpaint_model_path: str
|
17 |
+
is_sdxl: bool = False
|
18 |
|
19 |
|
20 |
def load_model_from_config(path):
|
|
|
24 |
config = json.loads(f.read())
|
25 |
model_path = config.get("model_path", path)
|
26 |
inpaint_model_path = config.get("inpaint_model_path", path)
|
27 |
+
is_sdxl = config.get("is_sdxl", False)
|
28 |
|
29 |
m_config.base_model_path = model_path
|
30 |
m_config.base_inpaint_model_path = inpaint_model_path
|
31 |
+
m_config.is_sdxl = is_sdxl
|
32 |
|
33 |
#
|
34 |
# if config.get("model_type") == "huggingface":
|
pyproject.toml
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
[tool.pyright]
|
2 |
-
venvPath = "
|
3 |
venv = "env"
|
4 |
exclude = ["env"]
|
|
|
1 |
[tool.pyright]
|
2 |
+
venvPath = "."
|
3 |
venv = "env"
|
4 |
exclude = ["env"]
|
requirements.txt
CHANGED
@@ -15,6 +15,7 @@ realesrgan==0.3.0
|
|
15 |
compel==1.0.4
|
16 |
scikit-image>=0.19.3
|
17 |
six==1.16.0
|
|
|
18 |
tifffile==2021.8.30
|
19 |
easydict==1.9.0
|
20 |
albumentations
|
@@ -32,10 +33,13 @@ xformers==0.0.21
|
|
32 |
scikit-image==0.19.3
|
33 |
omegaconf==2.3.0
|
34 |
webdataset==0.2.48
|
|
|
35 |
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
|
36 |
python-dateutil==2.8.2
|
37 |
PyYAML
|
38 |
invisible-watermark
|
39 |
torchvision==0.15.2
|
|
|
|
|
40 |
imgaug==0.4.0
|
41 |
tqdm==4.64.1
|
|
|
15 |
compel==1.0.4
|
16 |
scikit-image>=0.19.3
|
17 |
six==1.16.0
|
18 |
+
psutil
|
19 |
tifffile==2021.8.30
|
20 |
easydict==1.9.0
|
21 |
albumentations
|
|
|
33 |
scikit-image==0.19.3
|
34 |
omegaconf==2.3.0
|
35 |
webdataset==0.2.48
|
36 |
+
invisible-watermark
|
37 |
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
|
38 |
python-dateutil==2.8.2
|
39 |
PyYAML
|
40 |
invisible-watermark
|
41 |
torchvision==0.15.2
|
42 |
+
onnx
|
43 |
+
onnxruntime-gpu
|
44 |
imgaug==0.4.0
|
45 |
tqdm==4.64.1
|