jayparmr committed
Commit b71808f
1 Parent(s): 2b1a525

Upload folder using huggingface_hub

handler.py CHANGED
@@ -1,11 +1,38 @@
+import json
+import os
+from pathlib import Path
 from typing import Any, Dict, List
 
 from inference import model_fn, predict_fn
+from internals.util.model_downloader import BaseModelDownloader
 
 
 class EndpointHandler:
     def __init__(self, path=""):
-        return model_fn(path)
+        self.model_dir = path
+
+        if os.path.exists(path + "/inference.json"):
+            with open(path + "/inference.json", "r") as f:
+                config = json.loads(f.read())
+                if config.get("model_type") == "huggingface":
+                    self.model_dir = config["model_path"]
+                if config.get("model_type") == "s3":
+                    s3_config = config["model_path"]["s3"]
+                    base_url = s3_config["base_url"]
+
+                    urls = [base_url + item for item in s3_config["paths"]]
+                    out_dir = Path.home() / ".cache" / "base_model"
+                    if out_dir.exists():
+                        print("Model already exist")
+                    else:
+                        print("Downloading model")
+                        BaseModelDownloader(
+                            urls, s3_config["paths"], out_dir
+                        ).download()
+
+                    self.model_dir = str(out_dir)
+
+        return model_fn(self.model_dir)
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         return predict_fn(data, None)

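For reference, the handler above implies a particular inference.json layout. Below is a minimal sketch of such a config, written as a Python dict; the key names (model_type, model_path, s3, base_url, paths) are the ones the handler reads, while the bucket URL and file list are placeholder values, not taken from this commit.

# Hypothetical inference.json contents, expressed as a Python dict for illustration.
config = {
    "model_type": "s3",  # the handler also accepts "huggingface"
    "model_path": {
        "s3": {
            # placeholder bucket; the handler builds urls as base_url + each path
            "base_url": "https://example-bucket.s3.amazonaws.com/base_model/",
            "paths": ["model_index.json", "unet/diffusion_pytorch_model.bin"],
        }
    },
}
# With "model_type": "huggingface", "model_path" is used directly as the model
# directory / repo id string handed to model_fn.
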
inference.py CHANGED
@@ -15,18 +15,13 @@ from internals.pipelines.safety_checker import SafetyChecker
 from internals.util.anomaly import remove_colors
 from internals.util.args import apply_style_args
 from internals.util.avatar import Avatar
-from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
-from internals.util.commons import (
-    download_image,
-    pickPoses,
-    upload_image,
-    upload_images,
-)
-from internals.util.config import (
-    num_return_sequences,
-    set_configs_from_task,
-    set_root_dir,
-)
+from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
+                                  clear_cuda_and_gc)
+from internals.util.commons import (download_image, pickPoses, upload_image,
+                                    upload_images)
+from internals.util.config import (get_model_dir, num_return_sequences,
+                                   set_configs_from_task, set_model_dir,
+                                   set_root_dir)
 from internals.util.failure_hander import FailureHandler
 from internals.util.lora_style import LoraStyle
 from internals.util.slack import Slack
@@ -442,27 +437,48 @@ def inpaint(task: Task):
     return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
 
 
+def load_model_by_task(task: Task):
+    if (
+        task.get_type()
+        in [
+            TaskType.TEXT_TO_IMAGE,
+            TaskType.IMAGE_TO_IMAGE,
+            TaskType.INPAINT,
+        ]
+        and not text2img_pipe.is_loaded()
+    ):
+        text2img_pipe.load(get_model_dir())
+        img2img_pipe.create(text2img_pipe)
+        inpainter.create(text2img_pipe)
+
+        safety_checker.apply(text2img_pipe)
+        safety_checker.apply(img2img_pipe)
+    else:
+        if task.get_type() == TaskType.TILE_UPSCALE:
+            controlnet.load_tile_upscaler()
+        elif task.get_type() == TaskType.CANNY:
+            controlnet.load_canny()
+        elif task.get_type() == TaskType.SCRIBBLE:
+            controlnet.load_scribble()
+        elif task.get_type() == TaskType.LINEARART:
+            controlnet.load_linearart()
+        elif task.get_type() == TaskType.POSE:
+            controlnet.load_pose()
+
+        safety_checker.apply(controlnet)
+
+
 def model_fn(model_dir):
     print("Logs: model loaded .... starts")
 
+    set_model_dir(model_dir)
     set_root_dir(__file__)
 
     FailureHandler.register()
 
     avatar.load_local(model_dir)
 
-    prompt_modifier.load()
-    pose_detector.load()
-    img2text.load()
-    img_classifier.load()
-
     lora_style.load(model_dir)
-    safety_checker.load()
-
-    controlnet.load(model_dir)
-    text2img_pipe.load(model_dir)
-    img2img_pipe.create(text2img_pipe)
-    inpainter.create(text2img_pipe)
 
     print("Logs: model loaded ....")
     return
@@ -479,10 +495,8 @@ def predict_fn(data, pipe):
     # Set set_environment
     set_configs_from_task(task)
 
-    # Apply safety checkers based on environment
-    safety_checker.apply(text2img_pipe)
-    safety_checker.apply(img2img_pipe)
-    safety_checker.apply(controlnet)
+    # Load model based on task
+    load_model_by_task(task)
 
     # Apply arguments
     apply_style_args(data)
@@ -497,10 +511,10 @@ def predict_fn(data, pipe):
 
     if task_type == TaskType.TEXT_TO_IMAGE:
         # character sheet
-        if "character sheet" in task.get_prompt().lower():
-            return pose(task, s3_outkey="", poses=pickPoses())
-        else:
-            return text2img(task)
+        # if "character sheet" in task.get_prompt().lower():
+        #     return pose(task, s3_outkey="", poses=pickPoses())
+        # else:
+        return text2img(task)
     elif task_type == TaskType.IMAGE_TO_IMAGE:
         return img2img(task)
     elif task_type == TaskType.CANNY:

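The net effect of the inference.py changes is that model_fn now only records the model directory and loads lightweight assets, while predict_fn loads the heavy pipelines on demand through load_model_by_task. A minimal sketch of that directory hand-off, assuming the config helpers added in internals/util/config.py below; the path is a placeholder, not a value from the commit.

# Hypothetical hand-off: model_fn stores the directory once, lazy loaders read it back.
from internals.util.config import get_model_dir, set_model_dir

set_model_dir("/opt/ml/model")  # placeholder path; done once inside model_fn
print(get_model_dir())          # later read by load_model_by_task -> text2img_pipe.load(...)
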
inference2.py CHANGED
@@ -7,17 +7,18 @@ from internals.data.task import ModelType, Task, TaskType
 from internals.pipelines.inpainter import InPainter
 from internals.pipelines.object_remove import ObjectRemoval
 from internals.pipelines.prompt_modifier import PromptModifier
-from internals.pipelines.remove_background import (RemoveBackground,
-                                                   RemoveBackgroundV2)
+from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
 from internals.pipelines.replace_background import ReplaceBackground
 from internals.pipelines.safety_checker import SafetyChecker
 from internals.pipelines.upscaler import Upscaler
 from internals.util.avatar import Avatar
 from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
-from internals.util.commons import (construct_default_s3_url, upload_image,
-                                    upload_images)
-from internals.util.config import (num_return_sequences, set_configs_from_task,
-                                   set_root_dir)
+from internals.util.commons import construct_default_s3_url, upload_image, upload_images
+from internals.util.config import (
+    num_return_sequences,
+    set_configs_from_task,
+    set_root_dir,
+)
 from internals.util.failure_hander import FailureHandler
 from internals.util.slack import Slack
 
@@ -189,6 +190,7 @@ def predict_fn(data, pipe):
 
     # Apply safety checker based on environment
     safety_checker.apply(inpainter)
+    safety_checker.apply(replace_background)
 
     # Fetch avatars
     avatar.fetch_from_network(task.get_model_id())

internals/data/task.py CHANGED
@@ -30,7 +30,7 @@ class Task:
     def __init__(self, data):
         self.__data = data
         if data.get("seed", -1) == None or self.get_seed() == -1:
-            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int64).max)
+            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int32).max)
         prompt = data.get("prompt", "")
         if prompt is None:
             self.__data["prompt"] = ""

internals/pipelines/commons.py CHANGED
@@ -7,7 +7,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
 from internals.data.result import Result
 from internals.pipelines.twoStepPipeline import two_step_pipeline
 from internals.util.commons import disable_safety_checker, download_image
-from internals.util.config import num_return_sequences
+from internals.util.config import get_hf_token, num_return_sequences
 
 
 class AbstractPipeline:
@@ -28,10 +28,15 @@ class Text2Img(AbstractPipeline):
 
     def load(self, model_dir: str):
         self.pipe = two_step_pipeline.from_pretrained(
-            model_dir, torch_dtype=torch.float16
+            model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
         ).to("cuda")
         self.__patch()
 
+    def is_loaded(self):
+        if hasattr(self, "pipe"):
+            return True
+        return False
+
     def create(self, pipeline: AbstractPipeline):
         self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
         self.__patch()
@@ -115,7 +120,7 @@ class Text2Img(AbstractPipeline):
 class Img2Img(AbstractPipeline):
     def load(self, model_dir: str):
         self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
-            model_dir, torch_dtype=torch.float16
+            model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
         ).to("cuda")
         self.__patch()
 

internals/pipelines/controlnets.py CHANGED
@@ -4,33 +4,43 @@ import cv2
 import numpy as np
 import torch
 from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
-from diffusers import (ControlNetModel, DiffusionPipeline,
-                       StableDiffusionControlNetPipeline,
-                       UniPCMultistepScheduler)
+from diffusers import (
+    ControlNetModel,
+    DiffusionPipeline,
+    StableDiffusionControlNetPipeline,
+    UniPCMultistepScheduler,
+)
 from PIL import Image
 from torch.nn import Linear
 from tqdm import gui
 
 from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline
-from internals.pipelines.tileUpscalePipeline import \
-    StableDiffusionControlNetImg2ImgPipeline
+from internals.pipelines.tileUpscalePipeline import (
+    StableDiffusionControlNetImg2ImgPipeline,
+)
 from internals.util.cache import clear_cuda_and_gc
 from internals.util.commons import download_image
+from internals.util.config import get_hf_token, get_model_dir
 
 
 class ControlNet(AbstractPipeline):
     __current_task_name = ""
+    __loaded = False
 
-    def load(self, model_dir: str):
-        # we will load canny by default
-        self.load_scribble()
+    def load(self):
+        if self.__loaded:
+            return
+
+        if not self.controlnet:
+            self.load_pose()
 
         # controlnet pipeline for tile upscaler
         pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
-            model_dir,
+            get_model_dir(),
             controlnet=self.controlnet,
             torch_dtype=torch.float16,
+            use_auth_token=get_hf_token(),
         ).to("cuda")
         # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
         pipe.enable_model_cpu_offload()
@@ -43,6 +53,8 @@ class ControlNet(AbstractPipeline):
         pipe2.enable_xformers_memory_efficient_attention()
         self.pipe2 = pipe2
 
+        self.__loaded = True
+
     def load_canny(self):
         if self.__current_task_name == "canny":
             return
@@ -51,6 +63,9 @@ class ControlNet(AbstractPipeline):
         ).to("cuda")
         self.__current_task_name = "canny"
         self.controlnet = canny
+
+        self.load()
+
         if hasattr(self, "pipe"):
             self.pipe.controlnet = canny
         if hasattr(self, "pipe2"):
@@ -65,6 +80,9 @@ class ControlNet(AbstractPipeline):
         ).to("cuda")
         self.__current_task_name = "pose"
         self.controlnet = pose
+
+        self.load()
+
         if hasattr(self, "pipe"):
             self.pipe.controlnet = pose
         if hasattr(self, "pipe2"):
@@ -79,6 +97,9 @@ class ControlNet(AbstractPipeline):
         ).to("cuda")
         self.__current_task_name = "tile_upscaler"
         self.controlnet = tile_upscaler
+
+        self.load()
+
         if hasattr(self, "pipe"):
             self.pipe.controlnet = tile_upscaler
         if hasattr(self, "pipe2"):
@@ -93,6 +114,9 @@ class ControlNet(AbstractPipeline):
         ).to("cuda")
         self.__current_task_name = "scribble"
         self.controlnet = scribble
+
+        self.load()
+
         if hasattr(self, "pipe"):
             self.pipe.controlnet = scribble
         if hasattr(self, "pipe2"):
@@ -108,6 +132,9 @@ class ControlNet(AbstractPipeline):
         ).to("cuda")
         self.__current_task_name = "linearart"
         self.controlnet = linearart
+
+        self.load()
+
         if hasattr(self, "pipe"):
             self.pipe.controlnet = linearart
         if hasattr(self, "pipe2"):

internals/pipelines/img_classifier.py CHANGED
@@ -6,16 +6,24 @@ from internals.util.commons import download_image
 
 
 class ImageClassifier:
+    __loaded = False
+
     def __init__(self, candidates: List[str] = ["realistic", "anime", "comic"]):
         self.__candidates = candidates
 
     def load(self):
+        if self.__loaded:
+            return
+
         self.pipe = pipeline(
             "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
         )
 
+        self.__loaded = True
+
     def classify(self, image_url: str, width: int, height: int) -> str:
+        self.load()
         image = download_image(image_url).resize((width, height))
         results = self.pipe.__call__([image], candidate_labels=self.__candidates)
         results = results[0]

internals/pipelines/img_to_text.py CHANGED
@@ -8,7 +8,12 @@ from internals.util.commons import download_image
 
 
 class Image2Text:
+    __loaded = False
+
     def load(self):
+        if self.__loaded:
+            return
+
         self.processor = BlipProcessor.from_pretrained(
             "Salesforce/blip-image-captioning-large"
         )
@@ -16,7 +21,10 @@ class Image2Text:
             "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
         ).to("cuda")
 
+        self.__loaded = True
+
     def process(self, imageUrl: str) -> str:
+        self.load()
         image = download_image(imageUrl).resize((512, 512))
         inputs = self.processor.__call__(image, return_tensors="pt").to(
             "cuda", torch.float16

internals/pipelines/pose_detector.py CHANGED
@@ -36,6 +36,7 @@ class PoseDetector:
         client_coordinates: Optional[dict],
     ) -> Image.Image:
         "Infer pose coordinates from image, map head and body coordinates to infered ones, create pose"
+        self.load()
         if type(image) is str:
             image = download_image(image)
 
@@ -103,6 +104,7 @@ class PoseDetector:
         return image
 
     def infer(self, image: Union[str, Image.Image], width, height) -> dict:
+        self.load()
         candidate = []
         subset = []
 

internals/pipelines/prompt_modifier.py CHANGED
@@ -4,11 +4,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 
 class PromptModifier:
+    __loaded = False
+
     def __init__(self, num_of_sequences: Optional[int] = 4):
         self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
         self.__num_of_sequences = num_of_sequences
 
     def load(self):
+        if self.__loaded:
+            return
         self.prompter_model = AutoModelForCausalLM.from_pretrained(
             "Gustavosta/MagicPrompt-Stable-Diffusion"
         )
@@ -18,7 +22,10 @@ class PromptModifier:
         self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
         self.prompter_tokenizer.padding_side = "left"
 
+        self.__loaded = True
+
     def modify(self, text: str, num_of_sequences: Optional[int] = None) -> List[str]:
+        self.load()
         eos_id = self.prompter_tokenizer.eos_token_id
         # restricted_words_list = ["octane", "cyber"]
         # restricted_words_token_ids = prompter_tokenizer(

internals/pipelines/replace_background.py CHANGED
@@ -12,13 +12,14 @@ from PIL import Image, ImageFilter, ImageOps
 
 import internals.util.image as ImageUtil
 from internals.data.result import Result
+from internals.pipelines.commons import AbstractPipeline
 from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.remove_background import RemoveBackgroundV2
 from internals.pipelines.upscaler import Upscaler
 from internals.util.commons import download_image
 
 
-class ReplaceBackground:
+class ReplaceBackground(AbstractPipeline):
     def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
         controlnet = ControlNetModel.from_pretrained(
             "lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16

internals/pipelines/safety_checker.py CHANGED
@@ -18,13 +18,24 @@ def cosine_distance(image_embeds, text_embeds):
 
 
 class SafetyChecker:
+    __loaded = False
+
     def load(self):
+        if self.__loaded:
+            return
+
         self.model = StableDiffusionSafetyCheckerV2.from_pretrained(
             "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
         ).to("cuda")
 
+        self.__loaded = True
+
     def apply(self, pipeline: AbstractPipeline):
+        self.load()
+
         model = self.model if not get_nsfw_access() else None
+        if not pipeline:
+            return
         if hasattr(pipeline, "pipe"):
             pipeline.pipe.safety_checker = model
         if hasattr(pipeline, "pipe2"):

internals/util/config.py CHANGED
@@ -7,10 +7,17 @@ nsfw_threshold = 0.0
 nsfw_access = False
 access_token = ""
 root_dir = ""
+model_dir = ""
+hf_token = "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"
 
 num_return_sequences = 4  # the number of results to generate
 
 
+def set_model_dir(dir: str):
+    global model_dir
+    model_dir = dir
+
+
 def set_root_dir(main_file: str):
     global root_dir
     root_dir = os.path.dirname(os.path.abspath(main_file))
@@ -28,6 +35,11 @@ def set_configs_from_task(task: Task):
     access_token = task.get_access_token()
 
 
+def get_model_dir():
+    global model_dir
+    return model_dir
+
+
 def get_root_dir():
     global root_dir
     return root_dir
@@ -48,6 +60,11 @@ def get_nsfw_access():
     return nsfw_access
 
 
+def get_hf_token():
+    global hf_token
+    return hf_token
+
+
 def api_headers():
     return {
         "Access-Token": access_token,

internals/util/model_downloader.py ADDED
@@ -0,0 +1,151 @@
+import json
+import os
+import shutil
+from pathlib import Path
+from threading import Thread
+from typing import Any, Dict, List
+
+import requests
+from tqdm import tqdm
+
+
+class BaseModelDownloader:
+    """
+    A utility for fast download of base model from S3 or any CDN served storage.
+    Works by downloading multiple files in parallel and dividing large files
+    into smaller chunks and combining them at the end.
+
+    Currently it uses multithreading (not multiprocessing) assuming GIL won't
+    interfere with network/disk IO.
+
+    Created by: KP
+    """
+
+    def __init__(self, urls: List[str], url_paths: List[str], out_dir: Path):
+        self.urls = urls
+        self.url_paths = url_paths
+        shutil.rmtree(out_dir, ignore_errors=True)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        self.out_dir = out_dir
+
+    def download(self):
+        threads = []
+        batch_urls = {}
+
+        for url, url_path in zip(self.urls, self.url_paths):
+            out_dir = self.out_dir / url_path
+            self.out_dir.parent.mkdir(parents=True, exist_ok=True)
+            if url.endswith(".bin"):
+                if "unet/" in url_path:
+                    thread = Thread(
+                        target=self.__download_parallel, args=(url, out_dir, 6)
+                    )
+                    thread.start()
+                    threads.append(thread)
+                else:
+                    thread = Thread(
+                        target=self.__download_files, args=([url], [out_dir])
+                    )
+                    thread.start()
+                    threads.append(thread)
+                    pass
+            else:
+                batch_urls[url] = out_dir
+
+        if batch_urls:
+            thread = Thread(
+                target=self.__download_files,
+                args=(list(batch_urls.keys()), list(batch_urls.values())),
+            )
+            thread.start()
+            threads.append(thread)
+            pass
+
+        for thread in threads:
+            thread.join()
+
+    def __download_parallel(self, url, output_filename, num_parts=4):
+        response = requests.head(url)
+        total_size = int(response.headers.get("content-length", 0))
+        print("total_size", total_size)
+
+        chunk_size = total_size // num_parts
+        ranges = [
+            (i * chunk_size, (i + 1) * chunk_size - 1) for i in range(num_parts - 1)
+        ]
+        ranges.append((ranges[-1][1] + 1, total_size))
+
+        print(ranges)
+
+        save_dir = Path.home() / ".cache" / "download_parts"
+        os.makedirs(save_dir, exist_ok=True)
+
+        threads = []
+        for i, (start, end) in enumerate(ranges):
+            thread = Thread(
+                target=self.__download_part, args=(url, start, end, i, save_dir)
+            )
+            thread.start()
+            threads.append(thread)
+
+        for thread in threads:
+            thread.join()
+
+        self.__combine_parts(save_dir, output_filename, num_parts)
+        os.rmdir(save_dir)
+
+    def __combine_parts(self, save_dir, output_filename, num_parts):
+        part_files = [os.path.join(save_dir, f"part_{i}.tmp") for i in range(num_parts)]
+
+        output_filename.parent.mkdir(parents=True, exist_ok=True)
+        with open(output_filename, "wb") as output_file:
+            for part_file in part_files:
+                print("combining: ", part_file)
+                with open(part_file, "rb") as part:
+                    output_file.write(part.read())
+
+            out_file_size = output_file.tell()
+            print("out_file_size", out_file_size)
+
+        for part_file in part_files:
+            os.remove(part_file)
+
+    def __download_part(self, url, start_byte, end_byte, part_num, save_dir):
+        headers = {"Range": f"bytes={start_byte}-{end_byte}"}
+        response = requests.get(url, headers=headers, stream=True)
+
+        part_filename = os.path.join(save_dir, f"part_{part_num}.tmp")
+        print("Downloading part: ", url, part_filename, end_byte - start_byte)
+
+        with open(part_filename, "wb") as part_file, tqdm(
+            desc=str(part_filename),
+            total=end_byte - start_byte,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as bar:
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:
+                    size = part_file.write(chunk)
+                    bar.update(size)
+
+        return part_filename
+
+    def __download_files(self, urls, out_paths: List[Path]):
+        for url, out_path in zip(urls, out_paths):
+            out_path.parent.mkdir(parents=True, exist_ok=True)
+            with requests.get(url, stream=True) as r:
+                print("Downloading: ", url)
+                total_size = int(r.headers.get("content-length", 0))
+                chunk_size = 8192
+                r.raise_for_status()
+                with open(out_path, "wb") as f, tqdm(
+                    desc=str(out_path),
+                    total=total_size,
+                    unit="B",
+                    unit_scale=True,
+                    unit_divisor=1024,
+                ) as bar:
+                    for data in r.iter_content(chunk_size=chunk_size):
+                        size = f.write(data)
+                        bar.update(size)

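As a usage note, handler.py above is the only caller of this class in the commit. A minimal sketch of that call pattern follows; the bucket URL and file list are placeholders, not values from the commit.

# Hypothetical usage mirroring handler.py; base_url and paths are placeholders.
from pathlib import Path

from internals.util.model_downloader import BaseModelDownloader

base_url = "https://example-bucket.s3.amazonaws.com/base_model/"
paths = ["model_index.json", "unet/diffusion_pytorch_model.bin"]
urls = [base_url + p for p in paths]

out_dir = Path.home() / ".cache" / "base_model"
BaseModelDownloader(urls, paths, out_dir).download()  # .bin files under unet/ are chunked in parallel
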
requirements.txt CHANGED
@@ -22,20 +22,19 @@ kornia==0.5.0
 pytorch-lightning==1.2.9
 mmpose==0.29.0
 mmdet==2.28.2
-mmengine
-pydash
-scikit-learn
-accelerate
-pandas
-xformers
-torchvision
-scikit-image
-omegaconf
-webdataset
-git+https://github.com/cloneofsimo/lora.git
+https://comic-assets.s3.ap-south-1.amazonaws.com/packages/v1/lora-diffusion-0.1.7.zip
+mmengine==0.8.4
+pydash==7.0.6
+scikit-learn==1.3.0
+accelerate==0.22.0
+pandas==2.0.3
+xformers==0.0.21
+scikit-image==0.19.3
+omegaconf==2.3.0
+webdataset==0.2.48
 https://comic-assets.s3.ap-south-1.amazonaws.com/packages/mmcv_full-1.7.0-cp39-cp39-linux_x86_64.whl
 python-dateutil==2.8.2
 PyYAML
-torchvision
-imgaug
-tqdm
+torchvision==0.15.2
+imgaug==0.4.0
+tqdm==4.64.1