Upload folder using huggingface_hub
Browse files- handler.py +28 -1
- inference.py +45 -31
- inference2.py +8 -6
- internals/data/task.py +1 -1
- internals/pipelines/commons.py +8 -3
- internals/pipelines/controlnets.py +36 -9
- internals/pipelines/img_classifier.py +8 -0
- internals/pipelines/img_to_text.py +8 -0
- internals/pipelines/pose_detector.py +2 -0
- internals/pipelines/prompt_modifier.py +7 -0
- internals/pipelines/replace_background.py +2 -1
- internals/pipelines/safety_checker.py +11 -0
- internals/util/config.py +17 -0
- internals/util/model_downloader.py +151 -0
- requirements.txt +13 -14
handler.py
CHANGED
@@ -1,11 +1,38 @@
|
|
|
|
|
|
|
|
1 |
from typing import Any, Dict, List
|
2 |
|
3 |
from inference import model_fn, predict_fn
|
|
|
4 |
|
5 |
|
6 |
class EndpointHandler:
|
7 |
def __init__(self, path=""):
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
11 |
return predict_fn(data, None)
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
from typing import Any, Dict, List
|
5 |
|
6 |
from inference import model_fn, predict_fn
|
7 |
+
from internals.util.model_downloader import BaseModelDownloader
|
8 |
|
9 |
|
10 |
class EndpointHandler:
    """Entry point used by the inference endpoint.

    On construction it works out where the base model lives (the endpoint
    directory itself, a Hugging Face model id, or files fetched from S3/CDN)
    and hands that location to ``model_fn``. Each call forwards the request
    payload to ``predict_fn``.
    """

    def __init__(self, path=""):
        """Resolve the model directory for *path* and initialise the models.

        Args:
            path: Directory of the deployed endpoint. It may contain an
                ``inference.json`` file that redirects the model location.
        """
        self.model_dir = self._resolve_model_dir(path)
        # model_fn only has side effects; __init__ must not return a value.
        model_fn(self.model_dir)

    @staticmethod
    def _resolve_model_dir(path):
        """Return the model location described by ``<path>/inference.json``.

        Falls back to *path* itself when the file is missing or names an
        unknown ``model_type``.
        """
        # os.path.join keeps an empty path relative to the working directory;
        # plain concatenation would have produced "/inference.json".
        config_path = os.path.join(path, "inference.json")
        if not os.path.exists(config_path):
            return path

        with open(config_path, "r") as f:
            config = json.load(f)

        model_type = config.get("model_type")
        if model_type == "huggingface":
            return config["model_path"]

        if model_type == "s3":
            s3_config = config["model_path"]["s3"]
            base_url = s3_config["base_url"]
            paths = s3_config["paths"]

            urls = [base_url + item for item in paths]
            out_dir = Path.home() / ".cache" / "base_model"
            # An existing but empty directory is what an aborted download
            # leaves behind, so it must not count as a cached model.
            if out_dir.is_dir() and any(out_dir.iterdir()):
                print("Model already exist")
            else:
                print("Downloading model")
                BaseModelDownloader(urls, paths, out_dir).download()

            return str(out_dir)

        return path

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run inference for one request payload and return predict_fn's result."""
        return predict_fn(data, None)
|
inference.py
CHANGED
@@ -15,18 +15,13 @@ from internals.pipelines.safety_checker import SafetyChecker
|
|
15 |
from internals.util.anomaly import remove_colors
|
16 |
from internals.util.args import apply_style_args
|
17 |
from internals.util.avatar import Avatar
|
18 |
-
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda,
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
)
|
25 |
-
from internals.util.config import (
|
26 |
-
num_return_sequences,
|
27 |
-
set_configs_from_task,
|
28 |
-
set_root_dir,
|
29 |
-
)
|
30 |
from internals.util.failure_hander import FailureHandler
|
31 |
from internals.util.lora_style import LoraStyle
|
32 |
from internals.util.slack import Slack
|
@@ -442,27 +437,48 @@ def inpaint(task: Task):
|
|
442 |
return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
443 |
|
444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
def model_fn(model_dir):
|
446 |
print("Logs: model loaded .... starts")
|
447 |
|
|
|
448 |
set_root_dir(__file__)
|
449 |
|
450 |
FailureHandler.register()
|
451 |
|
452 |
avatar.load_local(model_dir)
|
453 |
|
454 |
-
prompt_modifier.load()
|
455 |
-
pose_detector.load()
|
456 |
-
img2text.load()
|
457 |
-
img_classifier.load()
|
458 |
-
|
459 |
lora_style.load(model_dir)
|
460 |
-
safety_checker.load()
|
461 |
-
|
462 |
-
controlnet.load(model_dir)
|
463 |
-
text2img_pipe.load(model_dir)
|
464 |
-
img2img_pipe.create(text2img_pipe)
|
465 |
-
inpainter.create(text2img_pipe)
|
466 |
|
467 |
print("Logs: model loaded ....")
|
468 |
return
|
@@ -479,10 +495,8 @@ def predict_fn(data, pipe):
|
|
479 |
# Set set_environment
|
480 |
set_configs_from_task(task)
|
481 |
|
482 |
-
#
|
483 |
-
|
484 |
-
safety_checker.apply(img2img_pipe)
|
485 |
-
safety_checker.apply(controlnet)
|
486 |
|
487 |
# Apply arguments
|
488 |
apply_style_args(data)
|
@@ -497,10 +511,10 @@ def predict_fn(data, pipe):
|
|
497 |
|
498 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
499 |
# character sheet
|
500 |
-
if "character sheet" in task.get_prompt().lower():
|
501 |
-
|
502 |
-
else:
|
503 |
-
|
504 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
505 |
return img2img(task)
|
506 |
elif task_type == TaskType.CANNY:
|
|
|
15 |
from internals.util.anomaly import remove_colors
|
16 |
from internals.util.args import apply_style_args
|
17 |
from internals.util.avatar import Avatar
|
18 |
+
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
|
19 |
+
clear_cuda_and_gc)
|
20 |
+
from internals.util.commons import (download_image, pickPoses, upload_image,
|
21 |
+
upload_images)
|
22 |
+
from internals.util.config import (get_model_dir, num_return_sequences,
|
23 |
+
set_configs_from_task, set_model_dir,
|
24 |
+
set_root_dir)
|
|
|
|
|
|
|
|
|
|
|
25 |
from internals.util.failure_hander import FailureHandler
|
26 |
from internals.util.lora_style import LoraStyle
|
27 |
from internals.util.slack import Slack
|
|
|
437 |
return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
438 |
|
439 |
|
440 |
+
def load_model_by_task(task: Task):
    """Lazily load the pipelines needed by *task* and apply the safety checker.

    Text-to-image, image-to-image and inpaint tasks share the base pipeline,
    which is loaded on first use. The ControlNet task types load their
    matching ControlNet model instead.

    The safety checker is applied on every call, not only on the call that
    loads the pipeline: ``SafetyChecker.apply`` consults the NSFW setting that
    ``set_configs_from_task`` updates per request, so applying it once would
    freeze the first request's setting for all later ones.
    """
    task_type = task.get_type()

    if task_type in [
        TaskType.TEXT_TO_IMAGE,
        TaskType.IMAGE_TO_IMAGE,
        TaskType.INPAINT,
    ]:
        if not text2img_pipe.is_loaded():
            text2img_pipe.load(get_model_dir())
            # img2img and inpaint reuse the components of the base pipeline.
            img2img_pipe.create(text2img_pipe)
            inpainter.create(text2img_pipe)

        safety_checker.apply(text2img_pipe)
        safety_checker.apply(img2img_pipe)
        return

    if task_type == TaskType.TILE_UPSCALE:
        controlnet.load_tile_upscaler()
    elif task_type == TaskType.CANNY:
        controlnet.load_canny()
    elif task_type == TaskType.SCRIBBLE:
        controlnet.load_scribble()
    elif task_type == TaskType.LINEARART:
        controlnet.load_linearart()
    elif task_type == TaskType.POSE:
        controlnet.load_pose()

    safety_checker.apply(controlnet)
|
469 |
+
|
470 |
+
|
471 |
def model_fn(model_dir):
    """Perform start-up initialisation for the endpoint.

    Records *model_dir* and the project root in the shared config, registers
    the failure handler, then calls ``avatar.load_local`` and
    ``lora_style.load`` with *model_dir*. Returns ``None``.
    """
    print("Logs: model loaded .... starts")

    # Store the locations first so later, lazily loaded pipelines can read them.
    set_model_dir(model_dir)
    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)
    lora_style.load(model_dir)

    print("Logs: model loaded ....")
|
|
|
495 |
# Set set_environment
|
496 |
set_configs_from_task(task)
|
497 |
|
498 |
+
# Load model based on task
|
499 |
+
load_model_by_task(task)
|
|
|
|
|
500 |
|
501 |
# Apply arguments
|
502 |
apply_style_args(data)
|
|
|
511 |
|
512 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
513 |
# character sheet
|
514 |
+
# if "character sheet" in task.get_prompt().lower():
|
515 |
+
# return pose(task, s3_outkey="", poses=pickPoses())
|
516 |
+
# else:
|
517 |
+
return text2img(task)
|
518 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
519 |
return img2img(task)
|
520 |
elif task_type == TaskType.CANNY:
|
inference2.py
CHANGED
@@ -7,17 +7,18 @@ from internals.data.task import ModelType, Task, TaskType
|
|
7 |
from internals.pipelines.inpainter import InPainter
|
8 |
from internals.pipelines.object_remove import ObjectRemoval
|
9 |
from internals.pipelines.prompt_modifier import PromptModifier
|
10 |
-
from internals.pipelines.remove_background import
|
11 |
-
RemoveBackgroundV2)
|
12 |
from internals.pipelines.replace_background import ReplaceBackground
|
13 |
from internals.pipelines.safety_checker import SafetyChecker
|
14 |
from internals.pipelines.upscaler import Upscaler
|
15 |
from internals.util.avatar import Avatar
|
16 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
17 |
-
from internals.util.commons import
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
21 |
from internals.util.failure_hander import FailureHandler
|
22 |
from internals.util.slack import Slack
|
23 |
|
@@ -189,6 +190,7 @@ def predict_fn(data, pipe):
|
|
189 |
|
190 |
# Apply safety checker based on environment
|
191 |
safety_checker.apply(inpainter)
|
|
|
192 |
|
193 |
# Fetch avatars
|
194 |
avatar.fetch_from_network(task.get_model_id())
|
|
|
7 |
from internals.pipelines.inpainter import InPainter
|
8 |
from internals.pipelines.object_remove import ObjectRemoval
|
9 |
from internals.pipelines.prompt_modifier import PromptModifier
|
10 |
+
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
|
|
|
11 |
from internals.pipelines.replace_background import ReplaceBackground
|
12 |
from internals.pipelines.safety_checker import SafetyChecker
|
13 |
from internals.pipelines.upscaler import Upscaler
|
14 |
from internals.util.avatar import Avatar
|
15 |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
|
16 |
+
from internals.util.commons import construct_default_s3_url, upload_image, upload_images
|
17 |
+
from internals.util.config import (
|
18 |
+
num_return_sequences,
|
19 |
+
set_configs_from_task,
|
20 |
+
set_root_dir,
|
21 |
+
)
|
22 |
from internals.util.failure_hander import FailureHandler
|
23 |
from internals.util.slack import Slack
|
24 |
|
|
|
190 |
|
191 |
# Apply safety checker based on environment
|
192 |
safety_checker.apply(inpainter)
|
193 |
+
safety_checker.apply(replace_background)
|
194 |
|
195 |
# Fetch avatars
|
196 |
avatar.fetch_from_network(task.get_model_id())
|
internals/data/task.py
CHANGED
@@ -30,7 +30,7 @@ class Task:
|
|
30 |
def __init__(self, data):
|
31 |
self.__data = data
|
32 |
if data.get("seed", -1) == None or self.get_seed() == -1:
|
33 |
-
self.__data["seed"] = np.random.randint(0, np.iinfo(np.
|
34 |
prompt = data.get("prompt", "")
|
35 |
if prompt is None:
|
36 |
self.__data["prompt"] = ""
|
|
|
30 |
def __init__(self, data):
|
31 |
self.__data = data
|
32 |
if data.get("seed", -1) == None or self.get_seed() == -1:
|
33 |
+
self.__data["seed"] = np.random.randint(0, np.iinfo(np.int32).max)
|
34 |
prompt = data.get("prompt", "")
|
35 |
if prompt is None:
|
36 |
self.__data["prompt"] = ""
|
internals/pipelines/commons.py
CHANGED
@@ -7,7 +7,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
|
|
7 |
from internals.data.result import Result
|
8 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
9 |
from internals.util.commons import disable_safety_checker, download_image
|
10 |
-
from internals.util.config import num_return_sequences
|
11 |
|
12 |
|
13 |
class AbstractPipeline:
|
@@ -28,10 +28,15 @@ class Text2Img(AbstractPipeline):
|
|
28 |
|
29 |
def load(self, model_dir: str):
|
30 |
self.pipe = two_step_pipeline.from_pretrained(
|
31 |
-
model_dir, torch_dtype=torch.float16
|
32 |
).to("cuda")
|
33 |
self.__patch()
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
def create(self, pipeline: AbstractPipeline):
|
36 |
self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
|
37 |
self.__patch()
|
@@ -115,7 +120,7 @@ class Text2Img(AbstractPipeline):
|
|
115 |
class Img2Img(AbstractPipeline):
|
116 |
def load(self, model_dir: str):
|
117 |
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
118 |
-
model_dir, torch_dtype=torch.float16
|
119 |
).to("cuda")
|
120 |
self.__patch()
|
121 |
|
|
|
7 |
from internals.data.result import Result
|
8 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
9 |
from internals.util.commons import disable_safety_checker, download_image
|
10 |
+
from internals.util.config import get_hf_token, num_return_sequences
|
11 |
|
12 |
|
13 |
class AbstractPipeline:
|
|
|
28 |
|
29 |
def load(self, model_dir: str):
|
30 |
self.pipe = two_step_pipeline.from_pretrained(
|
31 |
+
model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
|
32 |
).to("cuda")
|
33 |
self.__patch()
|
34 |
|
35 |
+
def is_loaded(self):
    """Return True once a ``pipe`` attribute has been set on this object."""
    return hasattr(self, "pipe")
|
39 |
+
|
40 |
def create(self, pipeline: AbstractPipeline):
|
41 |
self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
|
42 |
self.__patch()
|
|
|
120 |
class Img2Img(AbstractPipeline):
|
121 |
def load(self, model_dir: str):
|
122 |
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
123 |
+
model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
|
124 |
).to("cuda")
|
125 |
self.__patch()
|
126 |
|
internals/pipelines/controlnets.py
CHANGED
@@ -4,33 +4,43 @@ import cv2
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
7 |
-
from diffusers import (
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
10 |
from PIL import Image
|
11 |
from torch.nn import Linear
|
12 |
from tqdm import gui
|
13 |
|
14 |
from internals.data.result import Result
|
15 |
from internals.pipelines.commons import AbstractPipeline
|
16 |
-
from internals.pipelines.tileUpscalePipeline import
|
17 |
-
StableDiffusionControlNetImg2ImgPipeline
|
|
|
18 |
from internals.util.cache import clear_cuda_and_gc
|
19 |
from internals.util.commons import download_image
|
|
|
20 |
|
21 |
|
22 |
class ControlNet(AbstractPipeline):
|
23 |
__current_task_name = ""
|
|
|
24 |
|
25 |
-
def load(self
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
28 |
|
29 |
# controlnet pipeline for tile upscaler
|
30 |
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
31 |
-
|
32 |
controlnet=self.controlnet,
|
33 |
torch_dtype=torch.float16,
|
|
|
34 |
).to("cuda")
|
35 |
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
36 |
pipe.enable_model_cpu_offload()
|
@@ -43,6 +53,8 @@ class ControlNet(AbstractPipeline):
|
|
43 |
pipe2.enable_xformers_memory_efficient_attention()
|
44 |
self.pipe2 = pipe2
|
45 |
|
|
|
|
|
46 |
def load_canny(self):
|
47 |
if self.__current_task_name == "canny":
|
48 |
return
|
@@ -51,6 +63,9 @@ class ControlNet(AbstractPipeline):
|
|
51 |
).to("cuda")
|
52 |
self.__current_task_name = "canny"
|
53 |
self.controlnet = canny
|
|
|
|
|
|
|
54 |
if hasattr(self, "pipe"):
|
55 |
self.pipe.controlnet = canny
|
56 |
if hasattr(self, "pipe2"):
|
@@ -65,6 +80,9 @@ class ControlNet(AbstractPipeline):
|
|
65 |
).to("cuda")
|
66 |
self.__current_task_name = "pose"
|
67 |
self.controlnet = pose
|
|
|
|
|
|
|
68 |
if hasattr(self, "pipe"):
|
69 |
self.pipe.controlnet = pose
|
70 |
if hasattr(self, "pipe2"):
|
@@ -79,6 +97,9 @@ class ControlNet(AbstractPipeline):
|
|
79 |
).to("cuda")
|
80 |
self.__current_task_name = "tile_upscaler"
|
81 |
self.controlnet = tile_upscaler
|
|
|
|
|
|
|
82 |
if hasattr(self, "pipe"):
|
83 |
self.pipe.controlnet = tile_upscaler
|
84 |
if hasattr(self, "pipe2"):
|
@@ -93,6 +114,9 @@ class ControlNet(AbstractPipeline):
|
|
93 |
).to("cuda")
|
94 |
self.__current_task_name = "scribble"
|
95 |
self.controlnet = scribble
|
|
|
|
|
|
|
96 |
if hasattr(self, "pipe"):
|
97 |
self.pipe.controlnet = scribble
|
98 |
if hasattr(self, "pipe2"):
|
@@ -108,6 +132,9 @@ class ControlNet(AbstractPipeline):
|
|
108 |
).to("cuda")
|
109 |
self.__current_task_name = "linearart"
|
110 |
self.controlnet = linearart
|
|
|
|
|
|
|
111 |
if hasattr(self, "pipe"):
|
112 |
self.pipe.controlnet = linearart
|
113 |
if hasattr(self, "pipe2"):
|
|
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
|
7 |
+
from diffusers import (
|
8 |
+
ControlNetModel,
|
9 |
+
DiffusionPipeline,
|
10 |
+
StableDiffusionControlNetPipeline,
|
11 |
+
UniPCMultistepScheduler,
|
12 |
+
)
|
13 |
from PIL import Image
|
14 |
from torch.nn import Linear
|
15 |
from tqdm import gui
|
16 |
|
17 |
from internals.data.result import Result
|
18 |
from internals.pipelines.commons import AbstractPipeline
|
19 |
+
from internals.pipelines.tileUpscalePipeline import (
|
20 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
21 |
+
)
|
22 |
from internals.util.cache import clear_cuda_and_gc
|
23 |
from internals.util.commons import download_image
|
24 |
+
from internals.util.config import get_hf_token, get_model_dir
|
25 |
|
26 |
|
27 |
class ControlNet(AbstractPipeline):
|
28 |
__current_task_name = ""
|
29 |
+
__loaded = False
|
30 |
|
31 |
+
def load(self):
|
32 |
+
if self.__loaded:
|
33 |
+
return
|
34 |
+
|
35 |
+
if not self.controlnet:
|
36 |
+
self.load_pose()
|
37 |
|
38 |
# controlnet pipeline for tile upscaler
|
39 |
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
40 |
+
get_model_dir(),
|
41 |
controlnet=self.controlnet,
|
42 |
torch_dtype=torch.float16,
|
43 |
+
use_auth_token=get_hf_token(),
|
44 |
).to("cuda")
|
45 |
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
46 |
pipe.enable_model_cpu_offload()
|
|
|
53 |
pipe2.enable_xformers_memory_efficient_attention()
|
54 |
self.pipe2 = pipe2
|
55 |
|
56 |
+
self.__loaded = True
|
57 |
+
|
58 |
def load_canny(self):
|
59 |
if self.__current_task_name == "canny":
|
60 |
return
|
|
|
63 |
).to("cuda")
|
64 |
self.__current_task_name = "canny"
|
65 |
self.controlnet = canny
|
66 |
+
|
67 |
+
self.load()
|
68 |
+
|
69 |
if hasattr(self, "pipe"):
|
70 |
self.pipe.controlnet = canny
|
71 |
if hasattr(self, "pipe2"):
|
|
|
80 |
).to("cuda")
|
81 |
self.__current_task_name = "pose"
|
82 |
self.controlnet = pose
|
83 |
+
|
84 |
+
self.load()
|
85 |
+
|
86 |
if hasattr(self, "pipe"):
|
87 |
self.pipe.controlnet = pose
|
88 |
if hasattr(self, "pipe2"):
|
|
|
97 |
).to("cuda")
|
98 |
self.__current_task_name = "tile_upscaler"
|
99 |
self.controlnet = tile_upscaler
|
100 |
+
|
101 |
+
self.load()
|
102 |
+
|
103 |
if hasattr(self, "pipe"):
|
104 |
self.pipe.controlnet = tile_upscaler
|
105 |
if hasattr(self, "pipe2"):
|
|
|
114 |
).to("cuda")
|
115 |
self.__current_task_name = "scribble"
|
116 |
self.controlnet = scribble
|
117 |
+
|
118 |
+
self.load()
|
119 |
+
|
120 |
if hasattr(self, "pipe"):
|
121 |
self.pipe.controlnet = scribble
|
122 |
if hasattr(self, "pipe2"):
|
|
|
132 |
).to("cuda")
|
133 |
self.__current_task_name = "linearart"
|
134 |
self.controlnet = linearart
|
135 |
+
|
136 |
+
self.load()
|
137 |
+
|
138 |
if hasattr(self, "pipe"):
|
139 |
self.pipe.controlnet = linearart
|
140 |
if hasattr(self, "pipe2"):
|
internals/pipelines/img_classifier.py
CHANGED
@@ -6,16 +6,24 @@ from internals.util.commons import download_image
|
|
6 |
|
7 |
|
8 |
class ImageClassifier:
|
|
|
|
|
9 |
def __init__(self, candidates: List[str] = ["realistic", "anime", "comic"]):
|
10 |
self.__candidates = candidates
|
11 |
|
12 |
def load(self):
|
|
|
|
|
|
|
13 |
self.pipe = pipeline(
|
14 |
"zero-shot-image-classification",
|
15 |
model="philschmid/clip-zero-shot-image-classification",
|
16 |
)
|
17 |
|
|
|
|
|
18 |
def classify(self, image_url: str, width: int, height: int) -> str:
|
|
|
19 |
image = download_image(image_url).resize((width, height))
|
20 |
results = self.pipe.__call__([image], candidate_labels=self.__candidates)
|
21 |
results = results[0]
|
|
|
6 |
|
7 |
|
8 |
class ImageClassifier:
|
9 |
+
__loaded = False
|
10 |
+
|
11 |
def __init__(self, candidates: List[str] = ["realistic", "anime", "comic"]):
|
12 |
self.__candidates = candidates
|
13 |
|
14 |
def load(self):
    """Create the zero-shot image-classification pipeline on first call.

    Later calls do nothing, so callers may invoke this before every use.
    """
    if not self.__loaded:
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )
        self.__loaded = True
|
24 |
+
|
25 |
def classify(self, image_url: str, width: int, height: int) -> str:
|
26 |
+
self.load()
|
27 |
image = download_image(image_url).resize((width, height))
|
28 |
results = self.pipe.__call__([image], candidate_labels=self.__candidates)
|
29 |
results = results[0]
|
internals/pipelines/img_to_text.py
CHANGED
@@ -8,7 +8,12 @@ from internals.util.commons import download_image
|
|
8 |
|
9 |
|
10 |
class Image2Text:
|
|
|
|
|
11 |
def load(self):
|
|
|
|
|
|
|
12 |
self.processor = BlipProcessor.from_pretrained(
|
13 |
"Salesforce/blip-image-captioning-large"
|
14 |
)
|
@@ -16,7 +21,10 @@ class Image2Text:
|
|
16 |
"Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
|
17 |
).to("cuda")
|
18 |
|
|
|
|
|
19 |
def process(self, imageUrl: str) -> str:
|
|
|
20 |
image = download_image(imageUrl).resize((512, 512))
|
21 |
inputs = self.processor.__call__(image, return_tensors="pt").to(
|
22 |
"cuda", torch.float16
|
|
|
8 |
|
9 |
|
10 |
class Image2Text:
|
11 |
+
__loaded = False
|
12 |
+
|
13 |
def load(self):
|
14 |
+
if self.__loaded:
|
15 |
+
return
|
16 |
+
|
17 |
self.processor = BlipProcessor.from_pretrained(
|
18 |
"Salesforce/blip-image-captioning-large"
|
19 |
)
|
|
|
21 |
"Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
|
22 |
).to("cuda")
|
23 |
|
24 |
+
self.__loaded = True
|
25 |
+
|
26 |
def process(self, imageUrl: str) -> str:
|
27 |
+
self.load()
|
28 |
image = download_image(imageUrl).resize((512, 512))
|
29 |
inputs = self.processor.__call__(image, return_tensors="pt").to(
|
30 |
"cuda", torch.float16
|
internals/pipelines/pose_detector.py
CHANGED
@@ -36,6 +36,7 @@ class PoseDetector:
|
|
36 |
client_coordinates: Optional[dict],
|
37 |
) -> Image.Image:
|
38 |
"Infer pose coordinates from image, map head and body coordinates to infered ones, create pose"
|
|
|
39 |
if type(image) is str:
|
40 |
image = download_image(image)
|
41 |
|
@@ -103,6 +104,7 @@ class PoseDetector:
|
|
103 |
return image
|
104 |
|
105 |
def infer(self, image: Union[str, Image.Image], width, height) -> dict:
|
|
|
106 |
candidate = []
|
107 |
subset = []
|
108 |
|
|
|
36 |
client_coordinates: Optional[dict],
|
37 |
) -> Image.Image:
|
38 |
"Infer pose coordinates from image, map head and body coordinates to infered ones, create pose"
|
39 |
+
self.load()
|
40 |
if type(image) is str:
|
41 |
image = download_image(image)
|
42 |
|
|
|
104 |
return image
|
105 |
|
106 |
def infer(self, image: Union[str, Image.Image], width, height) -> dict:
|
107 |
+
self.load()
|
108 |
candidate = []
|
109 |
subset = []
|
110 |
|
internals/pipelines/prompt_modifier.py
CHANGED
@@ -4,11 +4,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
4 |
|
5 |
|
6 |
class PromptModifier:
|
|
|
|
|
7 |
def __init__(self, num_of_sequences: Optional[int] = 4):
|
8 |
self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
|
9 |
self.__num_of_sequences = num_of_sequences
|
10 |
|
11 |
def load(self):
|
|
|
|
|
12 |
self.prompter_model = AutoModelForCausalLM.from_pretrained(
|
13 |
"Gustavosta/MagicPrompt-Stable-Diffusion"
|
14 |
)
|
@@ -18,7 +22,10 @@ class PromptModifier:
|
|
18 |
self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
|
19 |
self.prompter_tokenizer.padding_side = "left"
|
20 |
|
|
|
|
|
21 |
def modify(self, text: str, num_of_sequences: Optional[int] = None) -> List[str]:
|
|
|
22 |
eos_id = self.prompter_tokenizer.eos_token_id
|
23 |
# restricted_words_list = ["octane", "cyber"]
|
24 |
# restricted_words_token_ids = prompter_tokenizer(
|
|
|
4 |
|
5 |
|
6 |
class PromptModifier:
|
7 |
+
__loaded = False
|
8 |
+
|
9 |
def __init__(self, num_of_sequences: Optional[int] = 4):
|
10 |
self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
|
11 |
self.__num_of_sequences = num_of_sequences
|
12 |
|
13 |
def load(self):
|
14 |
+
if self.__loaded:
|
15 |
+
return
|
16 |
self.prompter_model = AutoModelForCausalLM.from_pretrained(
|
17 |
"Gustavosta/MagicPrompt-Stable-Diffusion"
|
18 |
)
|
|
|
22 |
self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
|
23 |
self.prompter_tokenizer.padding_side = "left"
|
24 |
|
25 |
+
self.__loaded = True
|
26 |
+
|
27 |
def modify(self, text: str, num_of_sequences: Optional[int] = None) -> List[str]:
|
28 |
+
self.load()
|
29 |
eos_id = self.prompter_tokenizer.eos_token_id
|
30 |
# restricted_words_list = ["octane", "cyber"]
|
31 |
# restricted_words_token_ids = prompter_tokenizer(
|
internals/pipelines/replace_background.py
CHANGED
@@ -12,13 +12,14 @@ from PIL import Image, ImageFilter, ImageOps
|
|
12 |
|
13 |
import internals.util.image as ImageUtil
|
14 |
from internals.data.result import Result
|
|
|
15 |
from internals.pipelines.controlnets import ControlNet
|
16 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
17 |
from internals.pipelines.upscaler import Upscaler
|
18 |
from internals.util.commons import download_image
|
19 |
|
20 |
|
21 |
-
class ReplaceBackground:
|
22 |
def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
|
23 |
controlnet = ControlNetModel.from_pretrained(
|
24 |
"lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16
|
|
|
12 |
|
13 |
import internals.util.image as ImageUtil
|
14 |
from internals.data.result import Result
|
15 |
+
from internals.pipelines.commons import AbstractPipeline
|
16 |
from internals.pipelines.controlnets import ControlNet
|
17 |
from internals.pipelines.remove_background import RemoveBackgroundV2
|
18 |
from internals.pipelines.upscaler import Upscaler
|
19 |
from internals.util.commons import download_image
|
20 |
|
21 |
|
22 |
+
class ReplaceBackground(AbstractPipeline):
|
23 |
def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
|
24 |
controlnet = ControlNetModel.from_pretrained(
|
25 |
"lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16
|
internals/pipelines/safety_checker.py
CHANGED
@@ -18,13 +18,24 @@ def cosine_distance(image_embeds, text_embeds):
|
|
18 |
|
19 |
|
20 |
class SafetyChecker:
|
|
|
|
|
21 |
def load(self):
|
|
|
|
|
|
|
22 |
self.model = StableDiffusionSafetyCheckerV2.from_pretrained(
|
23 |
"CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
|
24 |
).to("cuda")
|
25 |
|
|
|
|
|
26 |
def apply(self, pipeline: AbstractPipeline):
|
|
|
|
|
27 |
model = self.model if not get_nsfw_access() else None
|
|
|
|
|
28 |
if hasattr(pipeline, "pipe"):
|
29 |
pipeline.pipe.safety_checker = model
|
30 |
if hasattr(pipeline, "pipe2"):
|
|
|
18 |
|
19 |
|
20 |
class SafetyChecker:
|
21 |
+
__loaded = False
|
22 |
+
|
23 |
def load(self):
    """Load the safety-checker model onto CUDA on first call; later calls are no-ops."""
    if not self.__loaded:
        checker = StableDiffusionSafetyCheckerV2.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
        )
        self.model = checker.to("cuda")
        self.__loaded = True
|
32 |
+
|
33 |
def apply(self, pipeline: AbstractPipeline):
|
34 |
+
self.load()
|
35 |
+
|
36 |
model = self.model if not get_nsfw_access() else None
|
37 |
+
if not pipeline:
|
38 |
+
return
|
39 |
if hasattr(pipeline, "pipe"):
|
40 |
pipeline.pipe.safety_checker = model
|
41 |
if hasattr(pipeline, "pipe2"):
|
internals/util/config.py
CHANGED
@@ -7,10 +7,17 @@ nsfw_threshold = 0.0
|
|
7 |
nsfw_access = False
|
8 |
access_token = ""
|
9 |
root_dir = ""
|
|
|
|
|
10 |
|
11 |
num_return_sequences = 4 # the number of results to generate
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
14 |
def set_root_dir(main_file: str):
|
15 |
global root_dir
|
16 |
root_dir = os.path.dirname(os.path.abspath(main_file))
|
@@ -28,6 +35,11 @@ def set_configs_from_task(task: Task):
|
|
28 |
access_token = task.get_access_token()
|
29 |
|
30 |
|
|
|
|
|
|
|
|
|
|
|
31 |
def get_root_dir():
|
32 |
global root_dir
|
33 |
return root_dir
|
@@ -48,6 +60,11 @@ def get_nsfw_access():
|
|
48 |
return nsfw_access
|
49 |
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
def api_headers():
|
52 |
return {
|
53 |
"Access-Token": access_token,
|
|
|
7 |
nsfw_access = False
|
8 |
access_token = ""
|
9 |
root_dir = ""
|
10 |
+
model_dir = ""
|
11 |
+
# SECURITY: never commit an access token. The value that used to be hard-coded
# here is exposed to anyone who can read the repository and should be revoked.
# The token is read from the environment instead; None means "no explicit token".
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
12 |
|
13 |
num_return_sequences = 4 # the number of results to generate
|
14 |
|
15 |
|
16 |
+
def set_model_dir(dir: str):
    """Store *dir* in the module-level ``model_dir`` setting."""
    global model_dir
    model_dir = dir
|
19 |
+
|
20 |
+
|
21 |
def set_root_dir(main_file: str):
|
22 |
global root_dir
|
23 |
root_dir = os.path.dirname(os.path.abspath(main_file))
|
|
|
35 |
access_token = task.get_access_token()
|
36 |
|
37 |
|
38 |
+
def get_model_dir():
    """Return the current value of the module-level ``model_dir`` setting."""
    # Reading a module global needs no ``global`` declaration.
    return model_dir
|
41 |
+
|
42 |
+
|
43 |
def get_root_dir():
|
44 |
global root_dir
|
45 |
return root_dir
|
|
|
60 |
return nsfw_access
|
61 |
|
62 |
|
63 |
+
def get_hf_token():
    """Return the current value of the module-level ``hf_token`` setting."""
    # Reading a module global needs no ``global`` declaration.
    return hf_token
|
66 |
+
|
67 |
+
|
68 |
def api_headers():
|
69 |
return {
|
70 |
"Access-Token": access_token,
|
internals/util/model_downloader.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
from pathlib import Path
|
5 |
+
from threading import Thread
|
6 |
+
from typing import Any, Dict, List
|
7 |
+
|
8 |
+
import requests
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
class BaseModelDownloader:
    """
    A utility for fast download of base model from S3 or any CDN served storage.
    Works by downloading multiple files in parallel and dividing large files
    into smaller chunks and combining them at the end.

    Currently it uses multithreading (not multiprocessing) assuming GIL won't
    interfere with network/disk IO.

    Failures inside worker threads are collected and re-raised from
    ``download()``; the output directory is removed in that case so that a
    partial download is never mistaken for a complete model.

    Created by: KP
    """

    def __init__(self, urls: List[str], url_paths: List[str], out_dir: Path):
        """Prepare an empty *out_dir* for the files named by *urls*.

        Args:
            urls: Full download URLs, one per file.
            url_paths: Relative destination path of each URL inside *out_dir*
                (same order as *urls*).
            out_dir: Target directory. Any existing content is deleted.
        """
        self.urls = urls
        self.url_paths = url_paths
        # Start from a clean directory so stale files never survive.
        shutil.rmtree(out_dir, ignore_errors=True)
        out_dir.mkdir(parents=True, exist_ok=True)
        self.out_dir = out_dir

    def download(self):
        """Download every file and block until all of them are complete.

        ``.bin`` files under ``unet/`` are fetched in 6 parallel byte ranges,
        other ``.bin`` files get one thread each, and all remaining files are
        fetched sequentially on one shared thread.

        Raises:
            RuntimeError: if any worker failed. The first failure is chained
                as the cause and *out_dir* is removed.
        """
        threads = []
        errors = []
        batch_urls = {}

        def spawn(target, *args):
            # Run target in a thread and keep its exception; an exception
            # raised inside a Thread would otherwise be lost to the caller.
            def runner():
                try:
                    target(*args)
                except Exception as exc:  # re-raised after join below
                    errors.append(exc)

            thread = Thread(target=runner)
            thread.start()
            threads.append(thread)

        for url, url_path in zip(self.urls, self.url_paths):
            out_path = self.out_dir / url_path
            if url.endswith(".bin"):
                if "unet/" in url_path:
                    spawn(self.__download_parallel, url, out_path, 6)
                else:
                    spawn(self.__download_files, [url], [out_path])
            else:
                batch_urls[url] = out_path

        if batch_urls:
            spawn(
                self.__download_files,
                list(batch_urls.keys()),
                list(batch_urls.values()),
            )

        for thread in threads:
            thread.join()

        if errors:
            # Do not leave a half-populated directory behind.
            shutil.rmtree(self.out_dir, ignore_errors=True)
            raise RuntimeError(
                f"Failed to download {len(errors)} file group(s)"
            ) from errors[0]

    @staticmethod
    def _split_ranges(total_size, num_parts):
        """Split ``total_size`` bytes into inclusive ``(start, end)`` ranges.

        At most *num_parts* ranges are produced (fewer when the file has
        fewer bytes than parts). The last range absorbs the remainder and
        ends at ``total_size - 1``, because HTTP byte ranges are inclusive.

        Raises:
            ValueError: if *total_size* is not positive.
        """
        if total_size <= 0:
            raise ValueError("total_size must be positive")
        num_parts = max(1, min(num_parts, total_size))
        chunk_size = total_size // num_parts
        ranges = [
            (i * chunk_size, (i + 1) * chunk_size - 1) for i in range(num_parts - 1)
        ]
        ranges.append(((num_parts - 1) * chunk_size, total_size - 1))
        return ranges

    def __download_parallel(self, url, output_filename, num_parts=4):
        """Download *url* to *output_filename* using parallel range requests."""
        import tempfile

        response = requests.head(url)
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        print("total_size", total_size)

        if total_size <= 0:
            # Without a known size the file cannot be split; stream it whole.
            self.__download_files([url], [output_filename])
            return

        ranges = self._split_ranges(total_size, num_parts)
        print(ranges)

        parts_root = Path.home() / ".cache" / "download_parts"
        os.makedirs(parts_root, exist_ok=True)
        # Each download gets its own directory: several files may be fetched
        # concurrently and they all use the same part_<i>.tmp names.
        save_dir = tempfile.mkdtemp(prefix="parts_", dir=parts_root)

        errors = []

        def fetch(index, start, end):
            try:
                self.__download_part(url, start, end, index, save_dir)
            except Exception as exc:  # re-raised after join below
                errors.append(exc)

        try:
            threads = [
                Thread(target=fetch, args=(i, start, end))
                for i, (start, end) in enumerate(ranges)
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            if errors:
                raise RuntimeError(f"Failed to download {url}") from errors[0]

            self.__combine_parts(save_dir, output_filename, len(ranges))
        finally:
            shutil.rmtree(save_dir, ignore_errors=True)

    def __combine_parts(self, save_dir, output_filename, num_parts):
        """Concatenate ``part_0.tmp`` .. ``part_<n-1>.tmp`` and delete the parts."""
        part_files = [os.path.join(save_dir, f"part_{i}.tmp") for i in range(num_parts)]

        output_filename.parent.mkdir(parents=True, exist_ok=True)
        with open(output_filename, "wb") as output_file:
            for part_file in part_files:
                print("combining: ", part_file)
                with open(part_file, "rb") as part:
                    # Stream the copy; a part can be far too large to hold
                    # in memory at once.
                    shutil.copyfileobj(part, output_file)

            out_file_size = output_file.tell()
            print("out_file_size", out_file_size)

        for part_file in part_files:
            os.remove(part_file)

    def __download_part(self, url, start_byte, end_byte, part_num, save_dir):
        """Download the inclusive byte range into ``part_<part_num>.tmp``.

        Returns:
            The path of the written part file.

        Raises:
            RuntimeError: if the server does not answer with 206 Partial
                Content. A full-body reply per part would corrupt the
                combined file.
        """
        headers = {"Range": f"bytes={start_byte}-{end_byte}"}
        part_filename = os.path.join(save_dir, f"part_{part_num}.tmp")
        expected_size = end_byte - start_byte + 1
        print("Downloading part: ", url, part_filename, expected_size)

        with requests.get(url, headers=headers, stream=True) as response:
            response.raise_for_status()
            if response.status_code != 206:
                raise RuntimeError(
                    f"Server ignored range request for {url} "
                    f"(status {response.status_code})"
                )

            with open(part_filename, "wb") as part_file, tqdm(
                desc=str(part_filename),
                total=expected_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        size = part_file.write(chunk)
                        bar.update(size)

        return part_filename

    def __download_files(self, urls, out_paths: List[Path]):
        """Stream each URL to its matching path, one after another."""
        for url, out_path in zip(urls, out_paths):
            out_path.parent.mkdir(parents=True, exist_ok=True)
            with requests.get(url, stream=True) as r:
                print("Downloading: ", url)
                # Fail before creating the file if the request was rejected.
                r.raise_for_status()
                total_size = int(r.headers.get("content-length", 0))
                chunk_size = 8192
                with open(out_path, "wb") as f, tqdm(
                    desc=str(out_path),
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in r.iter_content(chunk_size=chunk_size):
                        size = f.write(data)
                        bar.update(size)
|
requirements.txt
CHANGED
@@ -22,20 +22,19 @@ kornia==0.5.0
|
|
22 |
pytorch-lightning==1.2.9
|
23 |
mmpose==0.29.0
|
24 |
mmdet==2.28.2
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
scikit-image
|
33 |
-
omegaconf
|
34 |
-
webdataset
|
35 |
-
git+https://github.com/cloneofsimo/lora.git
|
36 |
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
|
37 |
python-dateutil==2.8.2
|
38 |
PyYAML
|
39 |
-
torchvision
|
40 |
-
imgaug
|
41 |
-
tqdm
|
|
|
22 |
pytorch-lightning==1.2.9
|
23 |
mmpose==0.29.0
|
24 |
mmdet==2.28.2
|
25 |
+
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/v1/lora-diffusion-0.1.7.zip
|
26 |
+
mmengine==0.8.4
|
27 |
+
pydash==7.0.6
|
28 |
+
scikit-learn==1.3.0
|
29 |
+
accelerate==0.22.0
|
30 |
+
pandas==2.0.3
|
31 |
+
xformers==0.0.21
|
32 |
+
scikit-image==0.19.3
|
33 |
+
omegaconf==2.3.0
|
34 |
+
webdataset==0.2.48
|
|
|
35 |
https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
|
36 |
python-dateutil==2.8.2
|
37 |
PyYAML
|
38 |
+
torchvision==0.15.2
|
39 |
+
imgaug==0.4.0
|
40 |
+
tqdm==4.64.1
|