jayparmr commited on
Commit
0eec7f4
1 Parent(s): ae524a9

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. inference.py +21 -6
  2. internals/data/task.py +16 -0
  3. internals/util/anomaly.py +60 -0
inference.py CHANGED
@@ -12,14 +12,21 @@ from internals.pipelines.inpainter import InPainter
12
  from internals.pipelines.pose_detector import PoseDetector
13
  from internals.pipelines.prompt_modifier import PromptModifier
14
  from internals.pipelines.safety_checker import SafetyChecker
 
15
  from internals.util.args import apply_style_args
16
  from internals.util.avatar import Avatar
17
- from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
18
- clear_cuda_and_gc)
19
- from internals.util.commons import (download_image, pickPoses, upload_image,
20
- upload_images)
21
- from internals.util.config import (num_return_sequences, set_configs_from_task,
22
- set_root_dir)
 
 
 
 
 
 
23
  from internals.util.failure_hander import FailureHandler
24
  from internals.util.lora_style import LoraStyle
25
  from internals.util.slack import Slack
@@ -123,6 +130,14 @@ def get_patched_prompt_tile_upscale(task: Task):
123
  else:
124
  prompt = img2text.process(task.get_imageUrl())
125
 
 
 
 
 
 
 
 
 
126
  prompt = avatar.add_code_names(prompt)
127
  prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
128
 
 
12
  from internals.pipelines.pose_detector import PoseDetector
13
  from internals.pipelines.prompt_modifier import PromptModifier
14
  from internals.pipelines.safety_checker import SafetyChecker
15
+ from internals.util.anomaly import remove_colors
16
  from internals.util.args import apply_style_args
17
  from internals.util.avatar import Avatar
18
+ from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
19
+ from internals.util.commons import (
20
+ download_image,
21
+ pickPoses,
22
+ upload_image,
23
+ upload_images,
24
+ )
25
+ from internals.util.config import (
26
+ num_return_sequences,
27
+ set_configs_from_task,
28
+ set_root_dir,
29
+ )
30
  from internals.util.failure_hander import FailureHandler
31
  from internals.util.lora_style import LoraStyle
32
  from internals.util.slack import Slack
 
130
  else:
131
  prompt = img2text.process(task.get_imageUrl())
132
 
133
+ # merge blip
134
+ if task.PROMPT.has_placeholder_blip_merge():
135
+ blip = img2text.process(task.get_imageUrl())
136
+ prompt = task.PROMPT.merge_blip(blip)
137
+
138
+ # remove anomalies in prompt
139
+ prompt = remove_colors(prompt)
140
+
141
  prompt = avatar.add_code_names(prompt)
142
  prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
143
 
internals/data/task.py CHANGED
@@ -1,4 +1,5 @@
1
  from enum import Enum
 
2
  from typing import Union
3
 
4
  import numpy as np
@@ -144,3 +145,18 @@ class Task:
144
 
145
  def get_raw(self) -> dict:
146
  return self.__data.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from enum import Enum
2
+ from functools import lru_cache
3
  from typing import Union
4
 
5
  import numpy as np
 
145
 
146
  def get_raw(self) -> dict:
147
  return self.__data.copy()
148
+
149
+ @property
150
+ @lru_cache(1)
151
+ def PROMPT(self):
152
+ class PromptMethods:
153
+ def __init__(self, task: Task):
154
+ self.__task = task
155
+
156
+ def has_placeholder_blip_merge(self) -> bool:
157
+ return "<blip:[merge]>" in self.__task.get_prompt()
158
+
159
+ def merge_blip(self, text: str) -> str:
160
+ return self.__task.get_prompt().replace("<blip:[merge]>", text)
161
+
162
+ return PromptMethods(self)
internals/util/anomaly.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ __colors = [
4
+ "Amber",
5
+ "Aqua",
6
+ "Azure",
7
+ "Beige",
8
+ "Black",
9
+ "Blue",
10
+ "Bronze",
11
+ "Brown",
12
+ "Chartreuse",
13
+ "Clay",
14
+ "Clay",
15
+ "Cyan",
16
+ "Dark",
17
+ "Gainsboro",
18
+ "Golden",
19
+ "Grape",
20
+ "Green",
21
+ "Grey",
22
+ "Indigo",
23
+ "Ivory",
24
+ "Light",
25
+ "Lime",
26
+ "Magenta",
27
+ "Maroon",
28
+ "Metallic",
29
+ "Mint",
30
+ "Mistry",
31
+ "Mustard",
32
+ "Navy",
33
+ "Neon",
34
+ "Off",
35
+ "Olive",
36
+ "Orange",
37
+ "Orange",
38
+ "Pea",
39
+ "Peru",
40
+ "Pink",
41
+ "Plum",
42
+ "Purple",
43
+ "Red",
44
+ "Ruby",
45
+ "Rust",
46
+ "Silver",
47
+ "Snow",
48
+ "Teal",
49
+ "Turquoise",
50
+ "Violet",
51
+ "Wheat",
52
+ "White",
53
+ "Yellow",
54
+ ]
55
+
56
+
57
+ def remove_colors(text: str) -> str:
58
+ for color in __colors:
59
+ text = re.sub(color, "", text, flags=re.I)
60
+ return text