from typing import List, Optional

from internals.data.task import Task
from internals.pipelines.commons import Text2Img
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.prompt_modifier import PromptModifier
from internals.util.anomaly import remove_colors
from internals.util.avatar import Avatar
from internals.util.config import get_num_return_sequences
from internals.util.lora_style import LoraStyle


def get_patched_prompt(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
):
    """Return (modified_prompts, original_prompts) for a generation task.

    Both lists have ``get_num_return_sequences()`` entries and carry avatar
    code names plus the task's LoRA style.
    """

    def add_style_and_character(prompt: List[str], additional: Optional[str] = None):
        # Mutates the list in place: inject avatar code names, prepend the
        # LoRA style and, if given, an additional prefix.
        for i in range(len(prompt)):
            prompt[i] = avatar.add_code_names(prompt[i])
            prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
            if additional:
                prompt[i] = additional + " " + prompt[i]

    prompt = task.get_prompt()

    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * get_num_return_sequences()

    ori_prompt = [task.get_prompt()] * get_num_return_sequences()

    class_name = None
    add_style_and_character(ori_prompt, class_name)
    add_style_and_character(prompt, class_name)

    print({"prompts": prompt})

    return (prompt, ori_prompt)
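
# Usage sketch (illustrative only, not part of the pipeline): the objects below
# are assumed to be already-constructed instances from this repo.
#
#   prompt, ori_prompt = get_patched_prompt(task, avatar, lora_style, prompt_modifier)
#   # `prompt` holds the engineered/styled variants, `ori_prompt` the original
#   # task prompt repeated get_num_return_sequences() times.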


def get_patched_prompt_text2img(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    prompt_modifier: PromptModifier,
) -> Text2Img.Params:
    """Build ``Text2Img.Params`` for a task, handling the split-prompt
    (left/right) case separately from the plain-prompt case."""

    def add_style_and_character(prompt: str, prepend: str = ""):
        prompt = avatar.add_code_names(prompt)
        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
        prompt = prepend + prompt
        return prompt

    if task.get_prompt_left() and task.get_prompt_right():
        # Split-prompt case: build separate prompt lists for the centre,
        # left and right portions of the image.
        prepend = ""
        if task.is_prompt_engineering():
            mod_prompt = prompt_modifier.modify(task.get_prompt())
        else:
            mod_prompt = [task.get_prompt()] * get_num_return_sequences()

        prompt, prompt_left, prompt_right = [], [], []
        for i in range(len(mod_prompt)):
            # Keep only the text the modifier appended to the base prompt.
            mp = mod_prompt[i].replace(task.get_prompt(), "")
            prompt.append(add_style_and_character(task.get_prompt(), prepend) + mp)
            prompt_left.append(
                add_style_and_character(task.get_prompt_left(), prepend) + mp
            )
            prompt_right.append(
                add_style_and_character(task.get_prompt_right(), prepend) + mp
            )

        params = Text2Img.Params(
            prompt=prompt,
            prompt_left=prompt_left,
            prompt_right=prompt_right,
        )
    else:
        if task.is_prompt_engineering():
            mod_prompt = prompt_modifier.modify(task.get_prompt())
        else:
            mod_prompt = [task.get_prompt()] * get_num_return_sequences()
        mod_prompt = [add_style_and_character(mp) for mp in mod_prompt]

        params = Text2Img.Params(
            prompt=[add_style_and_character(task.get_prompt())]
            * get_num_return_sequences(),
            modified_prompt=mod_prompt,
        )

    print(params)

    return params
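
# Usage sketch (illustrative only): builds Text2Img.Params for a downstream
# text-to-image call; `text2img.process(...)` is an assumed pipeline entry
# point, not something defined in this module.
#
#   params = get_patched_prompt_text2img(task, avatar, lora_style, prompt_modifier)
#   # images = text2img.process(params)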


def get_patched_prompt_tile_upscale(
    task: Task,
    avatar: Avatar,
    lora_style: LoraStyle,
    img_classifier: ImageClassifier,
    img2text: Image2Text,
):
    """Build the single string prompt used for tile upscale, falling back to
    image captioning when the task has no prompt of its own."""
    if task.get_prompt():
        prompt = task.get_prompt()
    else:
        # No user prompt: derive one from the input image.
        prompt = img2text.process(task.get_imageUrl())

    # Merge an image caption into the prompt when it contains a placeholder.
    if task.PROMPT.has_placeholder_blip_merge():
        blip = img2text.process(task.get_imageUrl())
        prompt = task.PROMPT.merge_blip(blip)

    # Remove colour names from the prompt (anomaly cleanup).
    prompt = remove_colors(prompt)

    prompt = avatar.add_code_names(prompt)
    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())

    # Only classify the image when no explicit style is set on the task.
    if not task.get_style():
        class_name = img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )
    else:
        class_name = ""
    prompt = class_name + " " + prompt
    prompt = prompt.strip()

    print({"prompt": prompt})

    return prompt
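
# Usage sketch (illustrative only): tile upscale consumes a single string
# prompt rather than a list; all argument objects are assumed to be
# pre-initialised instances from this repo.
#
#   prompt = get_patched_prompt_tile_upscale(
#       task, avatar, lora_style, img_classifier, img2text
#   )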