Upload folder using huggingface_hub
- handler.py +2 -23
- inference.py +7 -4
- inference2.py +4 -2
- internals/pipelines/controlnets.py +8 -13
- internals/pipelines/inpainter.py +12 -2
- internals/pipelines/replace_background.py +28 -50
- internals/util/config.py +14 -8
- internals/util/model_loader.py +187 -0
- requirements.txt +1 -1
handler.py
CHANGED

@@ -4,8 +4,8 @@ from pathlib import Path
 from typing import Any, Dict, List
 
 from inference import model_fn, predict_fn
-from internals.util.config import set_hf_cache_dir
-from internals.util.
+from internals.util.config import set_hf_cache_dir, set_model_config
+from internals.util.model_loader import load_model_from_config
 
 
 class EndpointHandler:
@@ -13,27 +13,6 @@ class EndpointHandler:
         set_hf_cache_dir(Path.home() / ".cache" / "hf_cache")
         self.model_dir = path
 
-        if os.path.exists(path + "/inference.json"):
-            with open(path + "/inference.json", "r") as f:
-                config = json.loads(f.read())
-                if config.get("model_type") == "huggingface":
-                    self.model_dir = config["model_path"]
-                if config.get("model_type") == "s3":
-                    s3_config = config["model_path"]["s3"]
-                    base_url = s3_config["base_url"]
-
-                    urls = [base_url + item for item in s3_config["paths"]]
-                    out_dir = Path.home() / ".cache" / "base_model"
-                    if out_dir.exists():
-                        print("Model already exist")
-                    else:
-                        print("Downloading model")
-                        BaseModelDownloader(
-                            urls, s3_config["paths"], out_dir
-                        ).download()
-
-                    self.model_dir = str(out_dir)
-
         return model_fn(self.model_dir)
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
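The inference.json handling that used to live here moves into internals/util/model_loader.py (added below). A minimal sketch of the new resolution flow, assuming a hypothetical model directory; the two JSON keys mirror the .get() calls in load_model_from_config:

# inference.json (illustrative values, not from the commit):
#   {"model_path": "author/base-model", "inpaint_model_path": "author/inpaint-model"}
from internals.util.model_loader import load_model_from_config

config = load_model_from_config("/opt/ml/model")  # hypothetical directory
# Both fields fall back to the directory itself when inference.json or a key is missing:
print(config.base_model_path, config.base_inpaint_model_path)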
inference.py
CHANGED

@@ -21,10 +21,11 @@ from internals.util.avatar import Avatar
 from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
 from internals.util.commons import download_image, upload_image, upload_images
 from internals.util.config import (get_model_dir, num_return_sequences,
-                                   set_configs_from_task,
+                                   set_configs_from_task, set_model_config,
                                    set_root_dir)
 from internals.util.failure_hander import FailureHandler
 from internals.util.lora_style import LoraStyle
+from internals.util.model_loader import load_model_from_config
 from internals.util.slack import Slack
 
 torch.backends.cudnn.benchmark = True
@@ -496,13 +497,14 @@ def load_model_by_task(task: Task):
     ):
         text2img_pipe.load(get_model_dir())
         img2img_pipe.create(text2img_pipe)
-        inpainter.
+        inpainter.load()
         high_res.load(img2img_pipe)
 
         safety_checker.apply(text2img_pipe)
         safety_checker.apply(img2img_pipe)
+        safety_checker.apply(inpainter)
     elif task.get_type() == TaskType.REPLACE_BG:
-        replace_background.load(
+        replace_background.load(inpainter=inpainter, high_res=high_res)
     else:
         if task.get_type() == TaskType.TILE_UPSCALE:
             controlnet.load_tile_upscaler()
@@ -521,7 +523,8 @@ def load_model_by_task(task: Task):
 def model_fn(model_dir):
     print("Logs: model loaded .... starts")
 
-
+    config = load_model_from_config(model_dir)
+    set_model_config(config)
     set_root_dir(__file__)
 
     FailureHandler.register()
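Read together, the two model_fn hunks give both entry points the same prologue; a sketch of the resulting order of operations:

def model_fn(model_dir):
    print("Logs: model loaded .... starts")

    config = load_model_from_config(model_dir)  # resolve base/inpaint model paths
    set_model_config(config)                    # make them visible to get_model_dir()
    set_root_dir(__file__)

    FailureHandler.register()
    # ... pipeline loading continues as before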
inference2.py
CHANGED

@@ -23,9 +23,10 @@ from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
 from internals.util.commons import (construct_default_s3_url, upload_image,
                                     upload_images)
 from internals.util.config import (num_return_sequences, set_configs_from_task,
-                                   set_root_dir)
+                                   set_model_config, set_root_dir)
 from internals.util.failure_hander import FailureHandler
 from internals.util.lora_style import LoraStyle
+from internals.util.model_loader import load_model_from_config
 from internals.util.slack import Slack
 
 torch.backends.cudnn.benchmark = True
@@ -214,7 +215,8 @@ def upscale_image(task: Task):
 def model_fn(model_dir):
     print("Logs: model loaded .... starts")
 
-
+    config = load_model_from_config(model_dir)
+    set_model_config(config)
     set_root_dir(__file__)
 
     FailureHandler.register()
internals/pipelines/controlnets.py
CHANGED

@@ -4,15 +4,11 @@ import cv2
 import numpy as np
 import torch
 from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
-from diffusers import (
-    ControlNetModel,
-    DiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    UniPCMultistepScheduler,
-)
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
-    MultiControlNetModel,
-)
+from diffusers import (ControlNetModel, DiffusionPipeline,
+                       StableDiffusionControlNetPipeline,
+                       UniPCMultistepScheduler)
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import \
+    MultiControlNetModel
 from PIL import Image
 from torch.nn import Linear
 from tqdm import gui
@@ -22,9 +18,8 @@ import internals.util.image as ImageUtil
 from external.midas import apply_midas
 from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline
-from internals.pipelines.tileUpscalePipeline import (
-    StableDiffusionControlNetImg2ImgPipeline
-)
+from internals.pipelines.tileUpscalePipeline import \
+    StableDiffusionControlNetImg2ImgPipeline
 from internals.util.cache import clear_cuda_and_gc
 from internals.util.commons import download_image
 from internals.util.config import get_hf_cache_dir, get_hf_token, get_model_dir
@@ -86,7 +81,7 @@ class ControlNet(AbstractPipeline):
         if self.__current_task_name == "pose":
             return
         pose = ControlNetModel.from_pretrained(
-            "lllyasviel/
+            "lllyasviel/control_v11p_sd15_openpose",
             torch_dtype=torch.float16,
             cache_dir=get_hf_cache_dir(),
         ).to("cuda")
internals/pipelines/inpainter.py
CHANGED

@@ -5,18 +5,28 @@ from diffusers import StableDiffusionInpaintPipeline
 
 from internals.pipelines.commons import AbstractPipeline
 from internals.util.commons import disable_safety_checker, download_image
-from internals.util.config import get_hf_cache_dir
+from internals.util.config import (get_hf_cache_dir, get_hf_token,
+                                   get_inpaint_model_path)
 
 
 class InPainter(AbstractPipeline):
+    __loaded = False
+
     def load(self):
+        if self.__loaded:
+            return
+
         self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
-
+            get_inpaint_model_path(),
             torch_dtype=torch.float16,
             cache_dir=get_hf_cache_dir(),
+            use_auth_token=get_hf_token(),
         ).to("cuda")
+
         disable_safety_checker(self.pipe)
 
+        self.__loaded = True
+
     def create(self, pipeline: AbstractPipeline):
         self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
             "cuda"
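The new __loaded flag makes InPainter.load() idempotent, which is what lets replace_background.load(inpainter=...) call it defensively (see that diff below). A small usage sketch:

inpainter = InPainter()
inpainter.load()  # loads get_inpaint_model_path() weights onto the GPU
inpainter.load()  # no-op: the __loaded guard returns immediately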
internals/pipelines/replace_background.py
CHANGED

@@ -2,6 +2,7 @@ from io import BytesIO
 from typing import List, Optional, Union
 
 import torch
+from cv2 import inpaint
 from diffusers import (ControlNetModel,
                        StableDiffusionControlNetInpaintPipeline,
                        StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
@@ -12,10 +13,12 @@ from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline
 from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.high_res import HighRes
+from internals.pipelines.inpainter import InPainter
 from internals.pipelines.remove_background import RemoveBackgroundV2
 from internals.pipelines.upscaler import Upscaler
 from internals.util.commons import download_image
-from internals.util.config import get_hf_cache_dir,
+from internals.util.config import (get_hf_cache_dir, get_hf_token,
+                                   get_inpaint_model_path, get_model_dir)
 
 
 class ReplaceBackground(AbstractPipeline):
@@ -25,7 +28,7 @@ class ReplaceBackground(AbstractPipeline):
         self,
         upscaler: Optional[Upscaler] = None,
         remove_background: Optional[RemoveBackgroundV2] = None,
-
+        inpainter: Optional[InPainter] = None,
         high_res: Optional[HighRes] = None,
     ):
         if self.__loaded:
@@ -35,18 +38,19 @@
             torch_dtype=torch.float16,
             cache_dir=get_hf_cache_dir(),
         ).to("cuda")
-        if
-
+        if inpainter:
+            inpainter.load()
             pipe = StableDiffusionControlNetInpaintPipeline(
-                **
+                **inpainter.pipe.components,
+                controlnet=controlnet_model,
             )
-            pipe.controlnet = controlnet_model
         else:
             pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                 "runwayml/stable-diffusion-inpainting",
                 controlnet=controlnet_model,
                 torch_dtype=torch.float16,
                 cache_dir=get_hf_cache_dir(),
+                use_auth_token=get_hf_token(),
             )
         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
         pipe.to("cuda")
@@ -104,14 +108,14 @@
 
         print(width, height, n_width, n_height)
 
+        resolution = min(n_width, n_height)
         if extend_object:
-            condition_image = ControlNet.linearart_condition_image(image)
-
-            )
+            condition_image = ControlNet.linearart_condition_image(image)
+            condition_image = ImageUtil.resize_image(condition_image, resolution)
             condition_image = ImageUtil.padd_image(condition_image, width, height)
             condition_image = condition_image.convert("RGB")
 
-            image =
+            image = ImageUtil.resize_image(image, resolution)
             image = ImageUtil.padd_image(image, width, height)
 
             mask = image.copy()
@@ -130,46 +134,20 @@
             condition_image = ControlNet.linearart_condition_image(image)
             mask = mask.convert("RGB")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-            for i, _ in enumerate(result.images):
-                out_bytes = self.upscaler.upscale(
-                    image=result.images[i],
-                    width=w,
-                    height=h,
-                    face_enhance=False,
-                    resize_dimension=max(width, height),
-                )
-                result.images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")
-            result = Result.from_result(result)
-        else:
-            result = self.pipe.__call__(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                image=image,
-                mask_image=mask,
-                control_image=condition_image,
-                controlnet_conditioning_scale=conditioning_scale,
-                guidance_scale=9,
-                strength=1,
-                height=height,
-                num_inference_steps=steps,
-                width=width,
-            )
-        result = Result.from_result(result)
+        result = self.pipe.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            image=image,
+            mask_image=mask,
+            control_image=condition_image,
+            controlnet_conditioning_scale=conditioning_scale,
+            guidance_scale=9,
+            strength=1,
+            height=height,
+            num_inference_steps=steps,
+            width=width,
+        )
+        result = Result.from_result(result)
 
         images, has_nsfw = result
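When an InPainter instance is passed in, the ControlNet inpaint pipeline is now assembled from the inpainter's already-loaded modules rather than pulling the runwayml checkpoint a second time. A sketch of that branch, assuming inpainter.load() has already run:

# diffusers pipelines expose their submodules as a components dict, so the
# loaded inpaint weights are reused and only the controlnet is added on top:
pipe = StableDiffusionControlNetInpaintPipeline(
    **inpainter.pipe.components,
    controlnet=controlnet_model,
)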
internals/util/config.py
CHANGED

@@ -3,13 +3,14 @@ from pathlib import Path
 from typing import Union
 
 from internals.data.task import Task
+from internals.util.model_loader import ModelConfig
 
 env = "prod"
 nsfw_threshold = 0.0
 nsfw_access = False
 access_token = ""
 root_dir = ""
-model_dir = ""
+model_config = None
 hf_token = "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"
 hf_cache_dir = "/tmp/hf_hub"
 
@@ -28,16 +29,16 @@ def get_hf_cache_dir():
     return hf_cache_dir
 
 
-def set_model_dir(dir: str):
-    global model_dir
-    model_dir = dir
-
-
 def set_root_dir(main_file: str):
     global root_dir
     root_dir = os.path.dirname(os.path.abspath(main_file))
 
 
+def set_model_config(config: ModelConfig):
+    global model_config
+    model_config = config
+
+
 def set_configs_from_task(task: Task):
     global env, nsfw_threshold, nsfw_access, access_token
     name = task.get_queue_name()
@@ -51,8 +52,13 @@ def set_configs_from_task(task: Task):
 
 
 def get_model_dir():
-    global model_dir
-    return model_dir
+    global model_config
+    return model_config.base_model_path  # pyright: ignore
+
+
+def get_inpaint_model_path():
+    global model_config
+    return model_config.base_inpaint_model_path  # pyright: ignore
 
 
 def get_root_dir():
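The module-level model_dir string gives way to a ModelConfig object; a short sketch tying the new setter and getters together (the path is hypothetical):

from internals.util.config import (get_inpaint_model_path, get_model_dir,
                                   set_model_config)
from internals.util.model_loader import load_model_from_config

set_model_config(load_model_from_config("/opt/ml/model"))
get_model_dir()           # -> ModelConfig.base_model_path
get_inpaint_model_path()  # -> ModelConfig.base_inpaint_model_path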
internals/util/model_loader.py
ADDED

@@ -0,0 +1,187 @@
+import json
+import os
+import shutil
+from dataclasses import dataclass
+from pathlib import Path
+from threading import Thread
+from typing import Any, Dict, List, Optional
+
+import requests
+from tqdm import tqdm
+
+
+@dataclass
+class ModelConfig:
+    base_model_path: str
+    base_inpaint_model_path: str
+
+
+def load_model_from_config(path):
+    m_config = ModelConfig(path, path)
+    if os.path.exists(path + "/inference.json"):
+        with open(path + "/inference.json", "r") as f:
+            config = json.loads(f.read())
+            model_path = config.get("model_path", path)
+            inpaint_model_path = config.get("inpaint_model_path", path)
+
+            m_config.base_model_path = model_path
+            m_config.base_inpaint_model_path = inpaint_model_path
+
+            #
+            # if config.get("model_type") == "huggingface":
+            #     model_dir = config["model_path"]
+            # if config.get("model_type") == "s3":
+            #     s3_config = config["model_path"]["s3"]
+            #     base_url = s3_config["base_url"]
+            #
+            #     urls = [base_url + item for item in s3_config["paths"]]
+            #     out_dir = Path.home() / ".cache" / "base_model"
+            #     if out_dir.exists():
+            #         print("Model already exist")
+            #     else:
+            #         print("Downloading model")
+            #         BaseModelDownloader(urls, s3_config["paths"], out_dir).download()
+            #     model_dir = str(out_dir)
+    return m_config
+
+
+class BaseModelDownloader:
+    """
+    A utility for fast download of base model from S3 or any CDN served storage.
+    Works by downloading multiple files in parallel and dividing large files
+    into smaller chunks and combining them at the end.
+
+    Currently it uses multithreading (not multiprocessing) assuming GIL won't
+    interfere with network/disk IO.
+
+    Created by: KP
+    """
+
+    def __init__(self, urls: List[str], url_paths: List[str], out_dir: Path):
+        self.urls = urls
+        self.url_paths = url_paths
+        shutil.rmtree(out_dir, ignore_errors=True)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        self.out_dir = out_dir
+
+    def download(self):
+        threads = []
+        batch_urls = {}
+
+        for url, url_path in zip(self.urls, self.url_paths):
+            out_dir = self.out_dir / url_path
+            self.out_dir.parent.mkdir(parents=True, exist_ok=True)
+            if url.endswith(".bin"):
+                if "unet/" in url_path:
+                    thread = Thread(
+                        target=self.__download_parallel, args=(url, out_dir, 6)
+                    )
+                    thread.start()
+                    threads.append(thread)
+                else:
+                    thread = Thread(
+                        target=self.__download_files, args=([url], [out_dir])
+                    )
+                    thread.start()
+                    threads.append(thread)
+                    pass
+            else:
+                batch_urls[url] = out_dir
+
+        if batch_urls:
+            thread = Thread(
+                target=self.__download_files,
+                args=(list(batch_urls.keys()), list(batch_urls.values())),
+            )
+            thread.start()
+            threads.append(thread)
+            pass
+
+        for thread in threads:
+            thread.join()
+
+    def __download_parallel(self, url, output_filename, num_parts=4):
+        response = requests.head(url)
+        total_size = int(response.headers.get("content-length", 0))
+        print("total_size", total_size)
+
+        chunk_size = total_size // num_parts
+        ranges = [
+            (i * chunk_size, (i + 1) * chunk_size - 1) for i in range(num_parts - 1)
+        ]
+        ranges.append((ranges[-1][1] + 1, total_size))
+
+        print(ranges)
+
+        save_dir = Path.home() / ".cache" / "download_parts"
+        os.makedirs(save_dir, exist_ok=True)
+
+        threads = []
+        for i, (start, end) in enumerate(ranges):
+            thread = Thread(
+                target=self.__download_part, args=(url, start, end, i, save_dir)
+            )
+            thread.start()
+            threads.append(thread)
+
+        for thread in threads:
+            thread.join()
+
+        self.__combine_parts(save_dir, output_filename, num_parts)
+        os.rmdir(save_dir)
+
+    def __combine_parts(self, save_dir, output_filename, num_parts):
+        part_files = [os.path.join(save_dir, f"part_{i}.tmp") for i in range(num_parts)]
+
+        output_filename.parent.mkdir(parents=True, exist_ok=True)
+        with open(output_filename, "wb") as output_file:
+            for part_file in part_files:
+                print("combining: ", part_file)
+                with open(part_file, "rb") as part:
+                    output_file.write(part.read())
+
+            out_file_size = output_file.tell()
+            print("out_file_size", out_file_size)
+
+        for part_file in part_files:
+            os.remove(part_file)
+
+    def __download_part(self, url, start_byte, end_byte, part_num, save_dir):
+        headers = {"Range": f"bytes={start_byte}-{end_byte}"}
+        response = requests.get(url, headers=headers, stream=True)
+
+        part_filename = os.path.join(save_dir, f"part_{part_num}.tmp")
+        print("Downloading part: ", url, part_filename, end_byte - start_byte)
+
+        with open(part_filename, "wb") as part_file, tqdm(
+            desc=str(part_filename),
+            total=end_byte - start_byte,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as bar:
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:
+                    size = part_file.write(chunk)
+                    bar.update(size)
+
+        return part_filename
+
+    def __download_files(self, urls, out_paths: List[Path]):
+        for url, out_path in zip(urls, out_paths):
+            out_path.parent.mkdir(parents=True, exist_ok=True)
+            with requests.get(url, stream=True) as r:
+                print("Downloading: ", url)
+                total_size = int(r.headers.get("content-length", 0))
+                chunk_size = 8192
+                r.raise_for_status()
+                with open(out_path, "wb") as f, tqdm(
+                    desc=str(out_path),
+                    total=total_size,
+                    unit="B",
+                    unit_scale=True,
+                    unit_divisor=1024,
+                ) as bar:
+                    for data in r.iter_content(chunk_size=chunk_size):
+                        size = f.write(data)
+                        bar.update(size)
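BaseModelDownloader keeps the constructor signature previously used by handler.py's s3 branch. A hedged usage sketch; the CDN host and file list are illustrative assumptions, not values from the commit:

from pathlib import Path

from internals.util.model_loader import BaseModelDownloader

base_url = "https://cdn.example.com/base_model/"  # hypothetical CDN prefix
paths = ["unet/diffusion_pytorch_model.bin", "model_index.json"]
urls = [base_url + p for p in paths]

out_dir = Path.home() / ".cache" / "base_model"
# .bin files under unet/ are range-split into 6 parallel chunks;
# everything else streams sequentially on a single worker thread.
BaseModelDownloader(urls, paths, out_dir).download()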
requirements.txt
CHANGED

@@ -5,7 +5,7 @@ fastapi==0.87.0
 Pillow==9.3.0
 redis==4.3.4
 requests==2.28.1
-transformers
+transformers==4.34.1
 rembg==2.0.30
 gfpgan==1.3.8
 rembg==2.0.30