File size: 4,235 Bytes
1bc457e 22df957 1bc457e 22df957 1bc457e 22df957 1bc457e 22df957 1bc457e 22df957 1bc457e 22df957 1bc457e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from typing import List, Optional
from internals.data.task import Task
from internals.pipelines.commons import Text2Img
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.prompt_modifier import PromptModifier
from internals.util.anomaly import remove_colors
from internals.util.avatar import Avatar
from internals.util.config import get_num_return_sequences
from internals.util.lora_style import LoraStyle
def get_patched_prompt(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
):
    """Build the patched prompt lists for ``task``.

    Two lists are produced:

    * the working prompts -- the result of ``prompt_modifier.modify`` when
      ``task.is_prompt_engineering()`` is truthy, otherwise the task prompt
      repeated ``get_num_return_sequences()`` times;
    * the original prompts -- always the task prompt repeated
      ``get_num_return_sequences()`` times.

    Every entry of both lists is passed through ``avatar.add_code_names`` and
    then ``lora_style.prepend_style_to_prompt`` with ``task.get_style()``.
    The working prompts are printed before returning.

    Returns:
        A ``(prompt, ori_prompt)`` tuple holding the two patched lists.
    """

    def _decorate_in_place(entries: List[str], prefix: Optional[str] = None):
        # Rewrites the entries of the list it receives; returns nothing.
        for idx, text in enumerate(entries):
            text = avatar.add_code_names(text)
            text = lora_style.prepend_style_to_prompt(text, task.get_style())
            if prefix:
                text = prefix + " " + text
            entries[idx] = text

    base_prompt = task.get_prompt()
    if task.is_prompt_engineering():
        # NOTE(review): modify() is assumed to return a mutable list of
        # strings, since its result is indexed and rewritten below -- confirm.
        patched = prompt_modifier.modify(base_prompt)
    else:
        patched = [base_prompt] * get_num_return_sequences()

    originals = [task.get_prompt()] * get_num_return_sequences()

    # No extra prefix is applied to either list, so the helper default is used.
    _decorate_in_place(originals)
    _decorate_in_place(patched)

    print({"prompts": patched})

    return (patched, originals)
def get_patched_prompt_text2img(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
) -> Text2Img.Params:
    """Assemble the ``Text2Img.Params`` prompt fields for ``task``.

    When the task has both a left and a right prompt, three parallel lists
    (``prompt``, ``prompt_left``, ``prompt_right``) are built. Each entry is
    the decorated base/left/right prompt followed by whatever remains of the
    corresponding prompt variant once the task prompt text is removed from it.

    Otherwise ``prompt`` is the decorated task prompt repeated
    ``get_num_return_sequences()`` times and ``modified_prompt`` holds the
    decorated prompt variants.

    "Decorated" means passed through ``avatar.add_code_names`` and then
    ``lora_style.prepend_style_to_prompt`` with ``task.get_style()``.
    The resulting params object is printed before it is returned.
    """

    def _decorate(text: str, prefix: str = ""):
        with_names = avatar.add_code_names(text)
        styled = lora_style.prepend_style_to_prompt(with_names, task.get_style())
        return prefix + styled

    def _prompt_variants():
        # Either the modifier's output or the plain prompt repeated.
        if task.is_prompt_engineering():
            return prompt_modifier.modify(task.get_prompt())
        return [task.get_prompt()] * get_num_return_sequences()

    if not (task.get_prompt_left() and task.get_prompt_right()):
        variants = [_decorate(variant) for variant in _prompt_variants()]

        params = Text2Img.Params(
            prompt=[_decorate(task.get_prompt())] * get_num_return_sequences(),
            modified_prompt=variants,
        )
    else:
        # Prefix hook for two-sided prompts; left empty
        # (candidate value: "2characters, ").
        prefix = ""

        whole, left, right = [], [], []
        for variant in _prompt_variants():
            # Keep only the text the variant has beyond the task prompt.
            extra = variant.replace(task.get_prompt(), "")
            whole.append(_decorate(task.get_prompt(), prefix) + extra)
            left.append(_decorate(task.get_prompt_left(), prefix) + extra)
            right.append(_decorate(task.get_prompt_right(), prefix) + extra)

        params = Text2Img.Params(
            prompt=whole,
            prompt_left=left,
            prompt_right=right,
        )

    print(params)

    return params
def get_patched_prompt_tile_upscale(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    img_classifier: ImageClassifier,
    img2text: Image2Text,
):
    """Build the single prompt string used for tile upscaling.

    Steps, in order:

    1. Use the task prompt if it is truthy, otherwise the result of
       ``img2text.process`` on the task image URL.
    2. If ``task.PROMPT.has_placeholder_blip_merge()`` is truthy, replace the
       prompt with ``task.PROMPT.merge_blip`` applied to a fresh
       ``img2text.process`` result.
    3. Pass the prompt through ``remove_colors``, ``avatar.add_code_names``
       and ``lora_style.prepend_style_to_prompt``.
    4. When the task has no style, prefix the prompt with the value returned
       by ``img_classifier.classify``; otherwise the prefix is empty.

    The stripped result is printed and returned.
    """
    image_text = (
        task.get_prompt()
        if task.get_prompt()
        else img2text.process(task.get_imageUrl())
    )

    # Merge the image-derived text into the prompt placeholder when present.
    if task.PROMPT.has_placeholder_blip_merge():
        image_text = task.PROMPT.merge_blip(img2text.process(task.get_imageUrl()))

    # Remove anomalies from the prompt before decorating it.
    cleaned = remove_colors(image_text)
    named = avatar.add_code_names(cleaned)
    styled = lora_style.prepend_style_to_prompt(named, task.get_style())

    label = (
        ""
        if task.get_style()
        else img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )
    )

    # strip() drops the separator space when the label is empty.
    result = (label + " " + styled).strip()

    print({"prompt": result})

    return result
|