# Source: Hugging Face repository file view ("raw").
# Uploaded by jayparmr with the message "Upload folder using huggingface_hub";
# commit 22df957 (verified), file size 4.24 kB.
from typing import List, Optional
from internals.data.task import Task
from internals.pipelines.commons import Text2Img
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.prompt_modifier import PromptModifier
from internals.util.anomaly import remove_colors
from internals.util.avatar import Avatar
from internals.util.config import get_num_return_sequences
from internals.util.lora_style import LoraStyle
def get_patched_prompt(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
):
    """Build the styled prompt lists for a generation task.

    Returns:
        A ``(prompt, ori_prompt)`` tuple of lists of strings.

        ``prompt`` holds one of two things:

        * the variants returned by ``prompt_modifier.modify``, when
          ``task.is_prompt_engineering()`` is true;
        * otherwise, the task prompt repeated ``get_num_return_sequences()``
          times.

        ``ori_prompt`` always holds the raw task prompt repeated
        ``get_num_return_sequences()`` times.

        Every entry in both lists is passed through ``avatar.add_code_names``
        and then has the task's LoRA style prepended via
        ``lora_style.prepend_style_to_prompt``. The styled ``prompt`` list is
        printed before returning.
    """

    def add_style_and_character(prompts: List[str]) -> List[str]:
        # Build a new list instead of overwriting entries in place, so a list
        # handed back by ``prompt_modifier.modify`` is never mutated.
        style = task.get_style()
        return [
            lora_style.prepend_style_to_prompt(avatar.add_code_names(p), style)
            for p in prompts
        ]

    raw_prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(raw_prompt)
    else:
        prompt = [raw_prompt] * get_num_return_sequences()
    ori_prompt = [raw_prompt] * get_num_return_sequences()

    ori_prompt = add_style_and_character(ori_prompt)
    prompt = add_style_and_character(prompt)
    print({"prompts": prompt})
    return (prompt, ori_prompt)
def get_patched_prompt_text2img(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
) -> Text2Img.Params:
    """Build ``Text2Img.Params`` for a text-to-image task.

    Two-character case (the task supplies both a left and a right prompt):
    three parallel lists are built, ``prompt``, ``prompt_left`` and
    ``prompt_right``. Each entry is the corresponding styled prompt followed
    by the text that prompt engineering added around the base prompt. That
    text is the engineered variant with the first occurrence of the base
    prompt removed, so it is empty when prompt engineering is off.

    Otherwise: ``prompt`` holds the styled base prompt repeated
    ``get_num_return_sequences()`` times. ``modified_prompt`` holds the styled
    variants, which are prompt-engineered when the task requests it.

    Styling means passing the text through ``avatar.add_code_names`` and then
    ``lora_style.prepend_style_to_prompt`` with the task's style. The
    resulting params object is printed before being returned.
    """

    def add_style_and_character(prompt: str, prepend: str = "") -> str:
        # Apply code names, then the LoRA style, then the optional prefix.
        prompt = avatar.add_code_names(prompt)
        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
        return prepend + prompt

    base_prompt = task.get_prompt()

    # Both branches below need the same list of (optionally engineered) prompts.
    if task.is_prompt_engineering():
        mod_prompt = prompt_modifier.modify(base_prompt)
    else:
        mod_prompt = [base_prompt] * get_num_return_sequences()

    prompt_left = task.get_prompt_left()
    prompt_right = task.get_prompt_right()
    if prompt_left and prompt_right:
        # NOTE: a "2characters, " tag was once prepended here; it is disabled.
        prepend = ""
        styled_base = add_style_and_character(base_prompt, prepend)
        styled_left = add_style_and_character(prompt_left, prepend)
        styled_right = add_style_and_character(prompt_right, prepend)
        # Remove only the first occurrence of the base prompt. A global replace
        # also deletes matching substrings inside the engineered text: base
        # "cat" would turn "cat, concatenated" into ", conenated".
        suffixes = [mp.replace(base_prompt, "", 1) for mp in mod_prompt]
        params = Text2Img.Params(
            prompt=[styled_base + suffix for suffix in suffixes],
            prompt_left=[styled_left + suffix for suffix in suffixes],
            prompt_right=[styled_right + suffix for suffix in suffixes],
        )
    else:
        modified_prompt = [add_style_and_character(mp) for mp in mod_prompt]
        params = Text2Img.Params(
            prompt=[add_style_and_character(base_prompt)]
            * get_num_return_sequences(),
            modified_prompt=modified_prompt,
        )
    print(params)
    return params
def get_patched_prompt_tile_upscale(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    img_classifier: ImageClassifier,
    img2text: Image2Text,
):
    """Build the text prompt for a tile-upscale task.

    Steps:

    1. Start from the task's prompt. If it is empty, use the caption that
       ``img2text.process`` produces for the task image.
    2. If ``task.PROMPT.has_placeholder_blip_merge()`` is true, replace the
       prompt with ``task.PROMPT.merge_blip(caption)``.
    3. Pass the result through ``remove_colors`` and
       ``avatar.add_code_names``, then prepend the task's LoRA style.
    4. When the task has no style, prepend the class name returned by
       ``img_classifier.classify``.

    The stripped prompt is printed and returned.
    """
    user_prompt = task.get_prompt()
    prompt = user_prompt if user_prompt else img2text.process(task.get_imageUrl())

    # Merge a fresh image caption into the prompt template when it asks for one.
    if task.PROMPT.has_placeholder_blip_merge():
        caption = img2text.process(task.get_imageUrl())
        prompt = task.PROMPT.merge_blip(caption)

    # Drop color words flagged as anomalies, then apply characters and style.
    cleaned = remove_colors(prompt)
    style = task.get_style()
    styled = lora_style.prepend_style_to_prompt(avatar.add_code_names(cleaned), style)

    if style:
        class_name = ""
    else:
        class_name = img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )

    final_prompt = (class_name + " " + styled).strip()
    print({"prompt": final_prompt})
    return final_prompt