jayparmr commited on
Commit
7fbdac4
1 Parent(s): 9387217

Upload folder using huggingface_hub

Browse files
external/scripts/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+
4
+ from internals.util import getcwd
5
+
6
+ path = os.path.join(getcwd(), "external/scripts")
7
+
8
+ __scripts__ = []
9
+ for name in os.listdir(path):
10
+ name = name.split("/")[-1].replace(".py", "")
11
+ imp = importlib.import_module(f"external.scripts.{name}")
12
+ if hasattr(imp, "Script") and imp not in __scripts__:
13
+ __scripts__.append(imp)
external/scripts/day_night_ip2p.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionInstructPix2PixPipeline
3
+
4
+ import internals.util.image as ImageUtil
5
+ from internals.data.dataAccessor import update_db
6
+ from internals.data.task import Task
7
+ from internals.util.cache import clear_cuda_and_gc
8
+ from internals.util.commons import download_image, upload_images
9
+ from internals.util.config import get_hf_token
10
+ from internals.util.slack import Slack
11
+
12
+ slack = Slack()
13
+
14
+
15
+ class Script:
16
+ def __init__(self, **kwargs):
17
+ self.__name__ = "day_night_ip2p"
18
+
19
+ @update_db
20
+ @slack.auto_send_alert
21
+ def __call__(self, task: Task, args: dict):
22
+ clear_cuda_and_gc()
23
+
24
+ model_id = args.get("model_id", None)
25
+ steps = args.get("steps", 50)
26
+ image_guidance_scale = args.get("image_guidance_scale", 1.5)
27
+ guidance_scale = args.get("guidance_scale", 7.5)
28
+
29
+ pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
30
+ model_id,
31
+ use_auth_token=get_hf_token(),
32
+ torch_dtype=torch.float16,
33
+ safety_checker=None,
34
+ ).to("cuda")
35
+ pipe.enable_xformers_memory_efficient_attention()
36
+
37
+ prompt = ["convert to night", "convert to evening", "convert to midnight"]
38
+ image = download_image(task.get_imageUrl())
39
+ image = ImageUtil.resize_image(image, 1024)
40
+
41
+ images = []
42
+ for p in prompt:
43
+ print("Generating: ", p)
44
+ image = pipe.__call__(
45
+ prompt=p,
46
+ num_inference_steps=steps,
47
+ image=image,
48
+ guidance_scale=guidance_scale,
49
+ num_images_per_prompt=1,
50
+ image_guidance_scale=image_guidance_scale,
51
+ ).images[0]
52
+ images.append(image)
53
+
54
+ generated_image_urls = upload_images(
55
+ images, "_" + self.__name__, task.get_taskId()
56
+ )
57
+
58
+ pipe = None
59
+ del pipe
60
+
61
+ clear_cuda_and_gc()
62
+
63
+ return {"generated_image_urls": generated_image_urls}
inference.py CHANGED
@@ -2,7 +2,9 @@ import os
2
  import traceback
3
  from typing import List, Optional
4
 
 
5
  import torch
 
6
 
7
  import internals.util.prompt as prompt_util
8
  from internals.data.dataAccessor import update_db, update_db_source_failed
@@ -54,6 +56,8 @@ safety_checker = SafetyChecker()
54
  slack = Slack()
55
  avatar = Avatar()
56
 
 
 
57
 
58
  def get_patched_prompt(task: Task):
59
  return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
@@ -533,6 +537,32 @@ def replace_bg(task: Task):
533
  }
534
 
535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  def load_model_by_task(task: Task):
537
  if not text2img_pipe.is_loaded():
538
  text2img_pipe.load(get_model_dir())
@@ -587,6 +617,8 @@ def predict_fn(data, pipe):
587
  task = Task(data)
588
  print("task is ", data)
589
 
 
 
590
  FailureHandler.handle(task)
591
 
592
  try:
@@ -629,6 +661,8 @@ def predict_fn(data, pipe):
629
  return linearart(task)
630
  elif task_type == TaskType.REPLACE_BG:
631
  return replace_bg(task)
 
 
632
  elif task_type == TaskType.SYSTEM_CMD:
633
  os.system(task.get_prompt())
634
  else:
 
2
  import traceback
3
  from typing import List, Optional
4
 
5
+ import pydash as _
6
  import torch
7
+ from numpy import who
8
 
9
  import internals.util.prompt as prompt_util
10
  from internals.data.dataAccessor import update_db, update_db_source_failed
 
56
  slack = Slack()
57
  avatar = Avatar()
58
 
59
+ custom_scripts: List = []
60
+
61
 
62
  def get_patched_prompt(task: Task):
63
  return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)
 
537
  }
538
 
539
 
540
+ def custom_action(task: Task):
541
+ from external.scripts import __scripts__
542
+
543
+ global custom_scripts
544
+ kwargs = {
545
+ "CONTROLNET": controlnet,
546
+ "LORASTYLE": lora_style,
547
+ }
548
+
549
+ torch.manual_seed(task.get_seed())
550
+
551
+ for script in __scripts__:
552
+ script = script.Script(**kwargs)
553
+ existing_script = _.find(
554
+ custom_scripts, lambda x: x.__name__ == script.__name__
555
+ )
556
+ if existing_script:
557
+ script = existing_script
558
+ else:
559
+ custom_scripts.append(script)
560
+
561
+ data = task.get_action_data()
562
+ if data["name"] == script.__name__:
563
+ return script(task, data)
564
+
565
+
566
  def load_model_by_task(task: Task):
567
  if not text2img_pipe.is_loaded():
568
  text2img_pipe.load(get_model_dir())
 
617
  task = Task(data)
618
  print("task is ", data)
619
 
620
+ clear_cuda_and_gc()
621
+
622
  FailureHandler.handle(task)
623
 
624
  try:
 
661
  return linearart(task)
662
  elif task_type == TaskType.REPLACE_BG:
663
  return replace_bg(task)
664
+ elif task_type == TaskType.CUSTOM_ACTION:
665
+ return custom_action(task)
666
  elif task_type == TaskType.SYSTEM_CMD:
667
  os.system(task.get_prompt())
668
  else:
internals/data/dataAccessor.py CHANGED
@@ -1,9 +1,9 @@
1
  import traceback
2
  from typing import Dict, List, Optional
3
 
4
- from requests.adapters import Retry, HTTPAdapter
5
  import requests
6
  from pydash import includes
 
7
 
8
  from internals.data.task import Task
9
  from internals.util.config import api_endpoint, api_headers
@@ -104,9 +104,13 @@ def update_db_source_failed(sourceId, userId):
104
 
105
  def update_db(func):
106
  def caller(*args, **kwargs):
107
- if type(args[0]) is not Task:
 
 
 
 
 
108
  raise Exception("First argument must be a Task object")
109
- task = args[0]
110
  try:
111
  updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
112
  rargs = func(*args, **kwargs)
 
1
  import traceback
2
  from typing import Dict, List, Optional
3
 
 
4
  import requests
5
  from pydash import includes
6
+ from requests.adapters import HTTPAdapter, Retry
7
 
8
  from internals.data.task import Task
9
  from internals.util.config import api_endpoint, api_headers
 
104
 
105
  def update_db(func):
106
  def caller(*args, **kwargs):
107
+ task = None
108
+ for arg in args:
109
+ if type(arg) is Task:
110
+ task = arg
111
+ break
112
+ if task is None:
113
  raise Exception("First argument must be a Task object")
 
114
  try:
115
  updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
116
  rargs = func(*args, **kwargs)
internals/data/task.py CHANGED
@@ -18,6 +18,7 @@ class TaskType(Enum):
18
  SCRIBBLE = "SCRIBBLE"
19
  LINEARART = "LINEARART"
20
  REPLACE_BG = "REPLACE_BG"
 
21
  SYSTEM_CMD = "SYSTEM_CMD"
22
 
23
 
@@ -148,6 +149,10 @@ class Task:
148
  def get_base_dimension(self):
149
  return self.__data.get("base_dimension", None)
150
 
 
 
 
 
151
  def get_raw(self) -> dict:
152
  return self.__data.copy()
153
 
 
18
  SCRIBBLE = "SCRIBBLE"
19
  LINEARART = "LINEARART"
20
  REPLACE_BG = "REPLACE_BG"
21
+ CUSTOM_ACTION = "CUSTOM_ACTION"
22
  SYSTEM_CMD = "SYSTEM_CMD"
23
 
24
 
 
149
  def get_base_dimension(self):
150
  return self.__data.get("base_dimension", None)
151
 
152
+ def get_action_data(self) -> dict:
153
+ "If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key"
154
+ return self.__data.get("action_data", {})
155
+
156
  def get_raw(self) -> dict:
157
  return self.__data.copy()
158
 
internals/pipelines/controlnets.py CHANGED
@@ -151,7 +151,6 @@ class ControlNet(AbstractPipeline):
151
 
152
  self.__load_pipeline(model, pipeline_type)
153
 
154
- self.network_model = model
155
  self.__current_task_name = task_name
156
 
157
  clear_cuda_and_gc()
@@ -247,6 +246,11 @@ class ControlNet(AbstractPipeline):
247
  if hasattr(self, "pipe2"):
248
  setattr(self.pipe2, "adapter", network_model)
249
 
 
 
 
 
 
250
  clear_cuda_and_gc()
251
 
252
  def process(self, **kwargs):
 
151
 
152
  self.__load_pipeline(model, pipeline_type)
153
 
 
154
  self.__current_task_name = task_name
155
 
156
  clear_cuda_and_gc()
 
246
  if hasattr(self, "pipe2"):
247
  setattr(self.pipe2, "adapter", network_model)
248
 
249
+ if hasattr(self, "pipe"):
250
+ self.pipe = self.pipe.to("cuda")
251
+ if hasattr(self, "pipe2"):
252
+ self.pipe2 = self.pipe2.to("cuda")
253
+
254
  clear_cuda_and_gc()
255
 
256
  def process(self, **kwargs):
internals/pipelines/high_res.py CHANGED
@@ -5,7 +5,8 @@ from PIL import Image
5
 
6
  from internals.data.result import Result
7
  from internals.pipelines.commons import AbstractPipeline, Img2Img
8
- from internals.util.config import get_model_dir, get_base_dimension
 
9
 
10
 
11
  class HighRes(AbstractPipeline):
@@ -32,6 +33,8 @@ class HighRes(AbstractPipeline):
32
  guidance_scale: int = 9,
33
  **kwargs,
34
  ):
 
 
35
  images = [image.resize((width, height)) for image in images]
36
  kwargs = {
37
  "prompt": prompt,
 
5
 
6
  from internals.data.result import Result
7
  from internals.pipelines.commons import AbstractPipeline, Img2Img
8
+ from internals.util.cache import clear_cuda_and_gc
9
+ from internals.util.config import get_base_dimension, get_model_dir
10
 
11
 
12
  class HighRes(AbstractPipeline):
 
33
  guidance_scale: int = 9,
34
  **kwargs,
35
  ):
36
+ clear_cuda_and_gc()
37
+
38
  images = [image.resize((width, height)) for image in images]
39
  kwargs = {
40
  "prompt": prompt,
internals/util/__init__.py CHANGED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from internals.util.config import get_root_dir
4
+
5
+
6
+ def getcwd():
7
+ return get_root_dir()
internals/util/lora_style.py CHANGED
@@ -10,8 +10,8 @@ from lora_diffusion import patch_pipe, tune_lora_scale
10
  from pydash import chain
11
 
12
  from internals.data.dataAccessor import getStyles
13
- from internals.util.config import get_is_sdxl
14
  from internals.util.commons import download_file
 
15
 
16
 
17
  class LoraStyle:
@@ -113,9 +113,6 @@ class LoraStyle:
113
  ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
114
  "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
115
  pipe = [pipe] if not isinstance(pipe, list) else pipe
116
- if get_is_sdxl():
117
- print("Warning: Lora is not supported on SDXL")
118
- return self.EmptyLoraPatcher(pipe)
119
 
120
  if key in self.__styles:
121
  style = self.__styles[key]
 
10
  from pydash import chain
11
 
12
  from internals.data.dataAccessor import getStyles
 
13
  from internals.util.commons import download_file
14
+ from internals.util.config import get_is_sdxl
15
 
16
 
17
  class LoraStyle:
 
113
  ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
114
  "Returns a lora patcher for the given `key` and `pipe`. `pipe` can also be a list of pipes"
115
  pipe = [pipe] if not isinstance(pipe, list) else pipe
 
 
 
116
 
117
  if key in self.__styles:
118
  style = self.__styles[key]
internals/util/slack.py CHANGED
@@ -55,7 +55,12 @@ class Slack:
55
  def auto_send_alert(self, func):
56
  def inner(*args, **kwargs):
57
  rargs = func(*args, **kwargs)
58
- self.send_alert(args[0], rargs)
 
 
 
 
 
59
  return rargs
60
 
61
  return inner
 
55
  def auto_send_alert(self, func):
56
  def inner(*args, **kwargs):
57
  rargs = func(*args, **kwargs)
58
+ task = Task({})
59
+ for arg in args:
60
+ if type(arg) is Task:
61
+ task = arg
62
+ break
63
+ self.send_alert(task, rargs)
64
  return rargs
65
 
66
  return inner