Upload folder using huggingface_hub
Browse files- external/scripts/__init__.py +13 -0
- external/scripts/day_night_ip2p.py +63 -0
- inference.py +34 -0
- internals/data/dataAccessor.py +7 -3
- internals/data/task.py +5 -0
- internals/pipelines/controlnets.py +5 -1
- internals/pipelines/high_res.py +4 -1
- internals/util/__init__.py +7 -0
- internals/util/lora_style.py +1 -4
- internals/util/slack.py +6 -1
external/scripts/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
from internals.util import getcwd
|
5 |
+
|
6 |
+
path = os.path.join(getcwd(), "external/scripts")
|
7 |
+
|
8 |
+
__scripts__ = []
|
9 |
+
for name in os.listdir(path):
|
10 |
+
name = name.split("/")[-1].replace(".py", "")
|
11 |
+
imp = importlib.import_module(f"external.scripts.{name}")
|
12 |
+
if hasattr(imp, "Script") and imp not in __scripts__:
|
13 |
+
__scripts__.append(imp)
|
external/scripts/day_night_ip2p.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import StableDiffusionInstructPix2PixPipeline
|
3 |
+
|
4 |
+
import internals.util.image as ImageUtil
|
5 |
+
from internals.data.dataAccessor import update_db
|
6 |
+
from internals.data.task import Task
|
7 |
+
from internals.util.cache import clear_cuda_and_gc
|
8 |
+
from internals.util.commons import download_image, upload_images
|
9 |
+
from internals.util.config import get_hf_token
|
10 |
+
from internals.util.slack import Slack
|
11 |
+
|
12 |
+
slack = Slack()
|
13 |
+
|
14 |
+
|
15 |
+
class Script:
|
16 |
+
def __init__(self, **kwargs):
|
17 |
+
self.__name__ = "day_night_ip2p"
|
18 |
+
|
19 |
+
@update_db
|
20 |
+
@slack.auto_send_alert
|
21 |
+
def __call__(self, task: Task, args: dict):
|
22 |
+
clear_cuda_and_gc()
|
23 |
+
|
24 |
+
model_id = args.get("model_id", None)
|
25 |
+
steps = args.get("steps", 50)
|
26 |
+
image_guidance_scale = args.get("image_guidance_scale", 1.5)
|
27 |
+
guidance_scale = args.get("guidance_scale", 7.5)
|
28 |
+
|
29 |
+
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
30 |
+
model_id,
|
31 |
+
use_auth_token=get_hf_token(),
|
32 |
+
torch_dtype=torch.float16,
|
33 |
+
safety_checker=None,
|
34 |
+
).to("cuda")
|
35 |
+
pipe.enable_xformers_memory_efficient_attention()
|
36 |
+
|
37 |
+
prompt = ["convert to night", "convert to evening", "convert to midnight"]
|
38 |
+
image = download_image(task.get_imageUrl())
|
39 |
+
image = ImageUtil.resize_image(image, 1024)
|
40 |
+
|
41 |
+
images = []
|
42 |
+
for p in prompt:
|
43 |
+
print("Generating: ", p)
|
44 |
+
image = pipe.__call__(
|
45 |
+
prompt=p,
|
46 |
+
num_inference_steps=steps,
|
47 |
+
image=image,
|
48 |
+
guidance_scale=guidance_scale,
|
49 |
+
num_images_per_prompt=1,
|
50 |
+
image_guidance_scale=image_guidance_scale,
|
51 |
+
).images[0]
|
52 |
+
images.append(image)
|
53 |
+
|
54 |
+
generated_image_urls = upload_images(
|
55 |
+
images, "_" + self.__name__, task.get_taskId()
|
56 |
+
)
|
57 |
+
|
58 |
+
pipe = None
|
59 |
+
del pipe
|
60 |
+
|
61 |
+
clear_cuda_and_gc()
|
62 |
+
|
63 |
+
return {"generated_image_urls": generated_image_urls}
|
inference.py
CHANGED
@@ -2,7 +2,9 @@ import os
|
|
2 |
import traceback
|
3 |
from typing import List, Optional
|
4 |
|
|
|
5 |
import torch
|
|
|
6 |
|
7 |
import internals.util.prompt as prompt_util
|
8 |
from internals.data.dataAccessor import update_db, update_db_source_failed
|
@@ -54,6 +56,8 @@ safety_checker = SafetyChecker()
|
|
54 |
slack = Slack()
|
55 |
avatar = Avatar()
|
56 |
|
|
|
|
|
57 |
|
58 |
def get_patched_prompt(task: Task):
|
59 |
return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
|
@@ -533,6 +537,32 @@ def replace_bg(task: Task):
|
|
533 |
}
|
534 |
|
535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
536 |
def load_model_by_task(task: Task):
|
537 |
if not text2img_pipe.is_loaded():
|
538 |
text2img_pipe.load(get_model_dir())
|
@@ -587,6 +617,8 @@ def predict_fn(data, pipe):
|
|
587 |
task = Task(data)
|
588 |
print("task is ", data)
|
589 |
|
|
|
|
|
590 |
FailureHandler.handle(task)
|
591 |
|
592 |
try:
|
@@ -629,6 +661,8 @@ def predict_fn(data, pipe):
|
|
629 |
return linearart(task)
|
630 |
elif task_type == TaskType.REPLACE_BG:
|
631 |
return replace_bg(task)
|
|
|
|
|
632 |
elif task_type == TaskType.SYSTEM_CMD:
|
633 |
os.system(task.get_prompt())
|
634 |
else:
|
|
|
2 |
import traceback
|
3 |
from typing import List, Optional
|
4 |
|
5 |
+
import pydash as _
|
6 |
import torch
|
7 |
+
from numpy import who
|
8 |
|
9 |
import internals.util.prompt as prompt_util
|
10 |
from internals.data.dataAccessor import update_db, update_db_source_failed
|
|
|
56 |
slack = Slack()
|
57 |
avatar = Avatar()
|
58 |
|
59 |
+
custom_scripts: List = []
|
60 |
+
|
61 |
|
62 |
def get_patched_prompt(task: Task):
|
63 |
return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
|
|
|
537 |
}
|
538 |
|
539 |
|
540 |
+
def custom_action(task: Task):
|
541 |
+
from external.scripts import __scripts__
|
542 |
+
|
543 |
+
global custom_scripts
|
544 |
+
kwargs = {
|
545 |
+
"CONTROLNET": controlnet,
|
546 |
+
"LORASTYLE": lora_style,
|
547 |
+
}
|
548 |
+
|
549 |
+
torch.manual_seed(task.get_seed())
|
550 |
+
|
551 |
+
for script in __scripts__:
|
552 |
+
script = script.Script(**kwargs)
|
553 |
+
existing_script = _.find(
|
554 |
+
custom_scripts, lambda x: x.__name__ == script.__name__
|
555 |
+
)
|
556 |
+
if existing_script:
|
557 |
+
script = existing_script
|
558 |
+
else:
|
559 |
+
custom_scripts.append(script)
|
560 |
+
|
561 |
+
data = task.get_action_data()
|
562 |
+
if data["name"] == script.__name__:
|
563 |
+
return script(task, data)
|
564 |
+
|
565 |
+
|
566 |
def load_model_by_task(task: Task):
|
567 |
if not text2img_pipe.is_loaded():
|
568 |
text2img_pipe.load(get_model_dir())
|
|
|
617 |
task = Task(data)
|
618 |
print("task is ", data)
|
619 |
|
620 |
+
clear_cuda_and_gc()
|
621 |
+
|
622 |
FailureHandler.handle(task)
|
623 |
|
624 |
try:
|
|
|
661 |
return linearart(task)
|
662 |
elif task_type == TaskType.REPLACE_BG:
|
663 |
return replace_bg(task)
|
664 |
+
elif task_type == TaskType.CUSTOM_ACTION:
|
665 |
+
return custom_action(task)
|
666 |
elif task_type == TaskType.SYSTEM_CMD:
|
667 |
os.system(task.get_prompt())
|
668 |
else:
|
internals/data/dataAccessor.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import traceback
|
2 |
from typing import Dict, List, Optional
|
3 |
|
4 |
-
from requests.adapters import Retry, HTTPAdapter
|
5 |
import requests
|
6 |
from pydash import includes
|
|
|
7 |
|
8 |
from internals.data.task import Task
|
9 |
from internals.util.config import api_endpoint, api_headers
|
@@ -104,9 +104,13 @@ def update_db_source_failed(sourceId, userId):
|
|
104 |
|
105 |
def update_db(func):
|
106 |
def caller(*args, **kwargs):
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
108 |
raise Exception("First argument must be a Task object")
|
109 |
-
task = args[0]
|
110 |
try:
|
111 |
updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
|
112 |
rargs = func(*args, **kwargs)
|
|
|
1 |
import traceback
|
2 |
from typing import Dict, List, Optional
|
3 |
|
|
|
4 |
import requests
|
5 |
from pydash import includes
|
6 |
+
from requests.adapters import HTTPAdapter, Retry
|
7 |
|
8 |
from internals.data.task import Task
|
9 |
from internals.util.config import api_endpoint, api_headers
|
|
|
104 |
|
105 |
def update_db(func):
|
106 |
def caller(*args, **kwargs):
|
107 |
+
task = None
|
108 |
+
for arg in args:
|
109 |
+
if type(arg) is Task:
|
110 |
+
task = arg
|
111 |
+
break
|
112 |
+
if task is None:
|
113 |
raise Exception("First argument must be a Task object")
|
|
|
114 |
try:
|
115 |
updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
|
116 |
rargs = func(*args, **kwargs)
|
internals/data/task.py
CHANGED
@@ -18,6 +18,7 @@ class TaskType(Enum):
|
|
18 |
SCRIBBLE = "SCRIBBLE"
|
19 |
LINEARART = "LINEARART"
|
20 |
REPLACE_BG = "REPLACE_BG"
|
|
|
21 |
SYSTEM_CMD = "SYSTEM_CMD"
|
22 |
|
23 |
|
@@ -148,6 +149,10 @@ class Task:
|
|
148 |
def get_base_dimension(self):
|
149 |
return self.__data.get("base_dimension", None)
|
150 |
|
|
|
|
|
|
|
|
|
151 |
def get_raw(self) -> dict:
|
152 |
return self.__data.copy()
|
153 |
|
|
|
18 |
SCRIBBLE = "SCRIBBLE"
|
19 |
LINEARART = "LINEARART"
|
20 |
REPLACE_BG = "REPLACE_BG"
|
21 |
+
CUSTOM_ACTION = "CUSTOM_ACTION"
|
22 |
SYSTEM_CMD = "SYSTEM_CMD"
|
23 |
|
24 |
|
|
|
149 |
def get_base_dimension(self):
|
150 |
return self.__data.get("base_dimension", None)
|
151 |
|
152 |
+
def get_action_data(self) -> dict:
|
153 |
+
"If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key"
|
154 |
+
return self.__data.get("action_data", {})
|
155 |
+
|
156 |
def get_raw(self) -> dict:
|
157 |
return self.__data.copy()
|
158 |
|
internals/pipelines/controlnets.py
CHANGED
@@ -151,7 +151,6 @@ class ControlNet(AbstractPipeline):
|
|
151 |
|
152 |
self.__load_pipeline(model, pipeline_type)
|
153 |
|
154 |
-
self.network_model = model
|
155 |
self.__current_task_name = task_name
|
156 |
|
157 |
clear_cuda_and_gc()
|
@@ -247,6 +246,11 @@ class ControlNet(AbstractPipeline):
|
|
247 |
if hasattr(self, "pipe2"):
|
248 |
setattr(self.pipe2, "adapter", network_model)
|
249 |
|
|
|
|
|
|
|
|
|
|
|
250 |
clear_cuda_and_gc()
|
251 |
|
252 |
def process(self, **kwargs):
|
|
|
151 |
|
152 |
self.__load_pipeline(model, pipeline_type)
|
153 |
|
|
|
154 |
self.__current_task_name = task_name
|
155 |
|
156 |
clear_cuda_and_gc()
|
|
|
246 |
if hasattr(self, "pipe2"):
|
247 |
setattr(self.pipe2, "adapter", network_model)
|
248 |
|
249 |
+
if hasattr(self, "pipe"):
|
250 |
+
self.pipe = self.pipe.to("cuda")
|
251 |
+
if hasattr(self, "pipe2"):
|
252 |
+
self.pipe2 = self.pipe2.to("cuda")
|
253 |
+
|
254 |
clear_cuda_and_gc()
|
255 |
|
256 |
def process(self, **kwargs):
|
internals/pipelines/high_res.py
CHANGED
@@ -5,7 +5,8 @@ from PIL import Image
|
|
5 |
|
6 |
from internals.data.result import Result
|
7 |
from internals.pipelines.commons import AbstractPipeline, Img2Img
|
8 |
-
from internals.util.
|
|
|
9 |
|
10 |
|
11 |
class HighRes(AbstractPipeline):
|
@@ -32,6 +33,8 @@ class HighRes(AbstractPipeline):
|
|
32 |
guidance_scale: int = 9,
|
33 |
**kwargs,
|
34 |
):
|
|
|
|
|
35 |
images = [image.resize((width, height)) for image in images]
|
36 |
kwargs = {
|
37 |
"prompt": prompt,
|
|
|
5 |
|
6 |
from internals.data.result import Result
|
7 |
from internals.pipelines.commons import AbstractPipeline, Img2Img
|
8 |
+
from internals.util.cache import clear_cuda_and_gc
|
9 |
+
from internals.util.config import get_base_dimension, get_model_dir
|
10 |
|
11 |
|
12 |
class HighRes(AbstractPipeline):
|
|
|
33 |
guidance_scale: int = 9,
|
34 |
**kwargs,
|
35 |
):
|
36 |
+
clear_cuda_and_gc()
|
37 |
+
|
38 |
images = [image.resize((width, height)) for image in images]
|
39 |
kwargs = {
|
40 |
"prompt": prompt,
|
internals/util/__init__.py
CHANGED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from internals.util.config import get_root_dir
|
4 |
+
|
5 |
+
|
6 |
+
def getcwd():
|
7 |
+
return get_root_dir()
|
internals/util/lora_style.py
CHANGED
@@ -10,8 +10,8 @@ from lora_diffusion import patch_pipe, tune_lora_scale
|
|
10 |
from pydash import chain
|
11 |
|
12 |
from internals.data.dataAccessor import getStyles
|
13 |
-
from internals.util.config import get_is_sdxl
|
14 |
from internals.util.commons import download_file
|
|
|
15 |
|
16 |
|
17 |
class LoraStyle:
|
@@ -113,9 +113,6 @@ class LoraStyle:
|
|
113 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
114 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
115 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
116 |
-
if get_is_sdxl():
|
117 |
-
print("Warning: Lora is not supported on SDXL")
|
118 |
-
return self.EmptyLoraPatcher(pipe)
|
119 |
|
120 |
if key in self.__styles:
|
121 |
style = self.__styles[key]
|
|
|
10 |
from pydash import chain
|
11 |
|
12 |
from internals.data.dataAccessor import getStyles
|
|
|
13 |
from internals.util.commons import download_file
|
14 |
+
from internals.util.config import get_is_sdxl
|
15 |
|
16 |
|
17 |
class LoraStyle:
|
|
|
113 |
) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
|
114 |
"Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
|
115 |
pipe = [pipe] if not isinstance(pipe, list) else pipe
|
|
|
|
|
|
|
116 |
|
117 |
if key in self.__styles:
|
118 |
style = self.__styles[key]
|
internals/util/slack.py
CHANGED
@@ -55,7 +55,12 @@ class Slack:
|
|
55 |
def auto_send_alert(self, func):
|
56 |
def inner(*args, **kwargs):
|
57 |
rargs = func(*args, **kwargs)
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
59 |
return rargs
|
60 |
|
61 |
return inner
|
|
|
55 |
def auto_send_alert(self, func):
|
56 |
def inner(*args, **kwargs):
|
57 |
rargs = func(*args, **kwargs)
|
58 |
+
task = Task({})
|
59 |
+
for arg in args:
|
60 |
+
if type(arg) is Task:
|
61 |
+
task = arg
|
62 |
+
break
|
63 |
+
self.send_alert(task, rargs)
|
64 |
return rargs
|
65 |
|
66 |
return inner
|