|
from typing import List, Optional |
|
|
|
import torch |
|
|
|
from internals.data.dataAccessor import update_db |
|
from internals.data.task import Task, TaskType |
|
from internals.pipelines.commons import Img2Img, Text2Img |
|
from internals.pipelines.controlnets import ControlNet |
|
from internals.pipelines.img_classifier import ImageClassifier |
|
from internals.pipelines.img_to_text import Image2Text |
|
from internals.pipelines.inpainter import InPainter |
|
from internals.pipelines.pose_detector import PoseDetector |
|
from internals.pipelines.prompt_modifier import PromptModifier |
|
from internals.pipelines.safety_checker import SafetyChecker |
|
from internals.util.anomaly import remove_colors |
|
from internals.util.args import apply_style_args |
|
from internals.util.avatar import Avatar |
|
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda, |
|
clear_cuda_and_gc) |
|
from internals.util.commons import (download_image, pickPoses, upload_image, |
|
upload_images) |
|
from internals.util.config import (get_model_dir, num_return_sequences, |
|
set_configs_from_task, set_model_dir, |
|
set_root_dir) |
|
from internals.util.failure_hander import FailureHandler |
|
from internals.util.lora_style import LoraStyle |
|
from internals.util.slack import Slack |
|
|
|
# Global PyTorch performance switches for this worker process:
# cudnn.benchmark lets cuDNN time candidate convolution algorithms and reuse the
# fastest one per input shape; allow_tf32 permits TF32 tensor-core math for CUDA
# matmuls (faster, slightly reduced precision).
torch.backends.cudnn.benchmark = True

torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not referenced anywhere in the visible part of this module;
# confirm whether another module reads it before removing.
auto_mode = False

# Module-level singletons shared by every task handler in this file; they live
# for the lifetime of the process and are reused across requests. Heavy model
# loading happens later (text2img_pipe.load(...) and the controlnet.load_*
# calls). Whether these constructors perform any I/O themselves is not visible
# here -- TODO confirm.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)

pose_detector = PoseDetector()

inpainter = InPainter()

img2text = Image2Text()

img_classifier = ImageClassifier()

controlnet = ControlNet()

lora_style = LoraStyle()

text2img_pipe = Text2Img()

img2img_pipe = Img2Img()

safety_checker = SafetyChecker()

slack = Slack()

avatar = Avatar()
|
|
|
|
|
def get_patched_prompt(task: Task):
    """Build the styled prompt lists used by the ControlNet/img2img/inpaint handlers.

    Returns a ``(prompt, ori_prompt)`` tuple of lists:

    * ``prompt``: the task prompt expanded by ``prompt_modifier`` when prompt
      engineering is enabled, otherwise the prompt repeated
      ``num_return_sequences`` times;
    * ``ori_prompt``: the raw task prompt repeated ``num_return_sequences`` times.

    Every entry of both lists has avatar code names substituted and the task's
    LoRA style prepended. The resulting ``prompt`` list is printed for logging.
    """

    def decorate(text: str, additional: Optional[str]) -> str:
        # Substitute avatar code names first, then prepend the LoRA style text.
        text = avatar.add_code_names(text)
        text = lora_style.prepend_style_to_prompt(text, task.get_style())
        if not additional:
            return text
        return additional + " " + text

    def add_style_and_character(prompts: List[str], additional: Optional[str] = None):
        # Rewrite the caller's list in place (same list object, new contents).
        prompts[:] = [decorate(entry, additional) for entry in prompts]

    engineered = task.is_prompt_engineering()
    raw_prompt = task.get_prompt()
    prompt = (
        prompt_modifier.modify(raw_prompt)
        if engineered
        else [raw_prompt] * num_return_sequences
    )

    ori_prompt = [task.get_prompt()] * num_return_sequences

    # Hook for an extra prefix (e.g. a class name); currently never set.
    prefix = None
    add_style_and_character(ori_prompt, prefix)
    add_style_and_character(prompt, prefix)

    print({"prompts": prompt})

    return prompt, ori_prompt
|
|
|
|
|
def get_patched_prompt_text2img(task: Task) -> Text2Img.Params:
    """Build the ``Text2Img.Params`` for a text-to-image task.

    The base prompt list comes from ``prompt_modifier`` when prompt engineering
    is enabled, otherwise it is the raw prompt repeated ``num_return_sequences``
    times. Every prompt gets avatar code names substituted and the task's LoRA
    style prepended.

    * With both a left and a right prompt: whatever the modifier added beyond
      the raw prompt is appended as a suffix to the styled main, left and right
      prompts, producing ``prompt``, ``prompt_left`` and ``prompt_right`` lists.
    * Otherwise: ``prompt`` is the styled raw prompt repeated
      ``num_return_sequences`` times and ``modified_prompt`` is the styled
      modifier output.

    The resulting params are printed for logging and returned.
    """

    def add_style_and_character(prompt: str, prepend: str = ""):
        prompt = avatar.add_code_names(prompt)
        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
        return prepend + prompt

    has_left_and_right = bool(task.get_prompt_left() and task.get_prompt_right())

    # Prompt engineering (or plain repetition) is the same for both layouts,
    # so it is computed once here rather than duplicated in each branch.
    if task.is_prompt_engineering():
        mod_prompt = prompt_modifier.modify(task.get_prompt())
    else:
        mod_prompt = [task.get_prompt()] * num_return_sequences

    if has_left_and_right:
        base_prompt = task.get_prompt()
        prompt, prompt_left, prompt_right = [], [], []
        for modified in mod_prompt:
            # Keep only the text the modifier added around the raw prompt so it
            # can be appended to each of the three styled prompts.
            suffix = modified.replace(base_prompt, "")
            prompt.append(add_style_and_character(base_prompt) + suffix)
            prompt_left.append(
                add_style_and_character(task.get_prompt_left()) + suffix
            )
            prompt_right.append(
                add_style_and_character(task.get_prompt_right()) + suffix
            )

        params = Text2Img.Params(
            prompt=prompt,
            prompt_left=prompt_left,
            prompt_right=prompt_right,
        )
    else:
        mod_prompt = [add_style_and_character(mp) for mp in mod_prompt]

        params = Text2Img.Params(
            prompt=[add_style_and_character(task.get_prompt())] * num_return_sequences,
            modified_prompt=mod_prompt,
        )

    print(params)

    return params
|
|
|
|
|
def get_patched_prompt_tile_upscale(task: Task):
    """Build the single prompt string used by the tile upscaler.

    If the task has no prompt, the image caption from ``img2text`` is used
    instead. When the task prompt contains a BLIP placeholder, the caption is
    merged in via ``task.PROMPT.merge_blip``. Color words are then removed via
    ``remove_colors``, avatar code names substituted and the LoRA style
    prepended. When the task has no style, the image classifier's class name is
    prefixed. The final prompt is stripped, printed and returned.
    """
    # Caption the source image at most once: both the empty-prompt fallback and
    # the placeholder merge need the caption of the same image URL.
    caption = None

    prompt = task.get_prompt()
    if not prompt:
        caption = img2text.process(task.get_imageUrl())
        prompt = caption

    if task.PROMPT.has_placeholder_blip_merge():
        if caption is None:
            caption = img2text.process(task.get_imageUrl())
        prompt = task.PROMPT.merge_blip(caption)

    prompt = remove_colors(prompt)

    prompt = avatar.add_code_names(prompt)
    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())

    if not task.get_style():
        # Guard against a None class name so concatenation cannot raise
        # TypeError (the classifier's return contract is not visible here).
        class_name = (
            img_classifier.classify(
                task.get_imageUrl(), task.get_width(), task.get_height()
            )
            or ""
        )
    else:
        class_name = ""
    prompt = f"{class_name} {prompt}".strip()

    print({"prompt": prompt})

    return prompt
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Run the Canny ControlNet on the task image and upload the results.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        images, has_nsfw = controlnet.process_canny(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_cy_guidance_scale(),
            negative_prompt=[
                f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
            ]
            * num_return_sequences,
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_canny", task.get_taskId())
    finally:
        # controlnet is a module-level singleton reused by later requests, so
        # undo the patch (presumably reverting the LoRA weights) even if
        # generation or upload fails.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task image with the tile ControlNet and upload the first result.

    The image is uploaded to ``crecoAI/<taskId>_tile_upscaler.png``. Returns a
    dict with ``modified_prompts`` (the single prompt string),
    ``generated_image_url`` and ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_tile_upscaler()

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()
    try:
        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not forwarded
        # here -- confirm whether process_tile_upscaler should receive it.
        images, has_nsfw = controlnet.process_tile_upscaler(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            resize_dimension=task.get_resize_dimension(),
            negative_prompt=task.get_negative_prompt(),
            guidance_scale=task.get_ti_guidance_scale(),
        )

        generated_image_url = upload_image(images[0], output_key)
    finally:
        # Undo the patch on the shared pipeline even if generation or upload
        # fails, so the next request does not inherit this task's style.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Run the scribble ControlNet on the task image and upload the results.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_scribble()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not forwarded
        # here -- confirm whether process_scribble accepts/needs it.
        images, has_nsfw = controlnet.process_scribble(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
    finally:
        # Undo the patch on the shared pipeline even if generation or upload
        # fails, so the next request does not inherit this task's style.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Run the line-art ControlNet on the task image and upload the results.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_linearart()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not forwarded
        # here -- confirm whether process_linearart accepts/needs it.
        images, has_nsfw = controlnet.process_linearart(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
    finally:
        # Undo the patch on the shared pipeline even if generation or upload
        # fails, so the next request does not inherit this task's style.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Run the pose ControlNet and upload the generated images.

    Args:
        task: the task being processed.
        s3_outkey: key suffix passed to ``upload_images`` for the generated images.
        poses: optional pre-computed pose images, one per return sequence. When
            ``None`` (the default), the pose is derived from the task image:
            used as-is (resized to the task size) when pose estimation is off,
            transformed from the client's coordinates when those are given,
            otherwise detected by ``controlnet.detect_pose``.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``. The first pose image is also uploaded to
        ``crecoAI/<taskId>_pose.png``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        # Only derive poses when the caller did not supply them; previously the
        # `poses` argument was unconditionally overwritten and thus ignored.
        if poses is None:
            if not task.get_pose_estimation():
                pose_image = download_image(task.get_imageUrl()).resize(
                    (task.get_width(), task.get_height())
                )
            elif task.get_pose_coordinates():
                pose_image = pose_detector.transform(
                    image=task.get_imageUrl(),
                    client_coordinates=task.get_pose_coordinates(),
                    width=task.get_width(),
                    height=task.get_height(),
                )
            else:
                pose_image = controlnet.detect_pose(task.get_imageUrl())
            poses = [pose_image] * num_return_sequences

        images, has_nsfw = controlnet.process_pose(
            prompt=prompt,
            image=poses,
            seed=task.get_seed(),
            steps=task.get_steps(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_po_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        pose_output_key = "crecoAI/{}_pose.png".format(task.get_taskId())
        upload_image(poses[0], pose_output_key)

        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
    finally:
        # Undo the patch on the shared pipeline even if any step above fails,
        # so the next request does not inherit this task's style.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from text with the shared text-to-image pipeline.

    Seeds torch's global RNG from the task, runs the pipeline with a fixed
    guidance scale of 7.5 and uploads the images. Returns the fields of the
    ``Text2Img.Params`` used, plus ``generated_image_urls`` and ``has_nsfw``.
    """
    params = get_patched_prompt_text2img(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = text2img_pipe.process(
            params=params,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=task.get_negative_prompt(),
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # text2img_pipe is a module-level singleton shared with img2img and
        # inpaint; undo the patch even on failure so later tasks are unaffected.
        lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Run image-to-image generation on the task image and upload the results.

    Seeds torch's global RNG from the task. Returns a dict with
    ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            strength=task.get_i2i_strength(),
            guidance_scale=task.get_i2i_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        # img2img_pipe is a module-level singleton; undo the patch even on
        # failure so the next request does not inherit this task's style.
        lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
|
|
|
|
|
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task image and upload the results.

    Returns a dict with ``modified_prompts`` and ``generated_image_urls``.
    CUDA memory is cleared via ``clear_cuda()`` whether or not the run succeeds.
    """
    # get_patched_prompt already logs {"prompts": ...}; no second print needed.
    prompt, _ = get_patched_prompt(task)

    try:
        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # This handler is not wrapped by auto_clear_cuda_and_gc, so release
        # CUDA memory here even when processing or upload fails.
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
|
|
|
|
def load_model_by_task(task: Task):
    """Ensure the pipeline needed by ``task`` is loaded before dispatch.

    Text-to-image, image-to-image and inpainting share ``text2img_pipe``: on
    first use it is loaded from the model dir, ``img2img_pipe`` and
    ``inpainter`` are created from it, and the safety checker is applied to the
    text2img and img2img pipelines. Later calls for these task types do nothing.

    Every other task type loads its ControlNet variant (when known) and applies
    the safety checker to ``controlnet``.
    """
    task_type = task.get_type()

    if task_type in (
        TaskType.TEXT_TO_IMAGE,
        TaskType.IMAGE_TO_IMAGE,
        TaskType.INPAINT,
    ):
        if not text2img_pipe.is_loaded():
            text2img_pipe.load(get_model_dir())
            img2img_pipe.create(text2img_pipe)
            inpainter.create(text2img_pipe)

            safety_checker.apply(text2img_pipe)
            safety_checker.apply(img2img_pipe)
        # Return early: once text2img_pipe is loaded these task types must not
        # fall through and apply the safety checker to an unrelated ControlNet.
        return

    controlnet_loaders = {
        TaskType.TILE_UPSCALE: controlnet.load_tile_upscaler,
        TaskType.CANNY: controlnet.load_canny,
        TaskType.SCRIBBLE: controlnet.load_scribble,
        TaskType.LINEARART: controlnet.load_linearart,
        TaskType.POSE: controlnet.load_pose,
    }
    loader = controlnet_loaders.get(task_type)
    if loader is not None:
        loader()

    safety_checker.apply(controlnet)
|
|
|
|
|
def model_fn(model_dir):
    """One-time initialisation hook; returns None.

    Stores ``model_dir`` via ``set_model_dir`` and passes this module's path to
    ``set_root_dir``. It then registers the FailureHandler and loads the local
    avatar data and LoRA styles from ``model_dir``.

    NOTE(review): the ``model_fn``/``predict_fn(data, pipe)`` pair presumably
    follows the SageMaker PyTorch inference-handler convention. If so, this
    function's return value (None) is what ``predict_fn`` receives as ``pipe``.
    Confirm against the deployment config.
    """
    print("Logs: model loaded .... starts")

    # Config setters run first; the loaders below presumably read these
    # directories, so keep this order.
    set_model_dir(model_dir)

    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)

    lora_style.load(model_dir)

    print("Logs: model loaded ....")

    return
|
|
|
|
|
@FailureHandler.clear
def predict_fn(data, pipe):
    """Handle one inference request.

    Builds a Task from ``data``, then does the following in order:

    1. Applies task configs.
    2. Loads the required pipeline.
    3. Applies style args.
    4. Refreshes LoRA styles and avatar data.
    5. Dispatches to the handler for the task type and returns its result.

    Any exception raised during preparation or dispatch, including an unknown
    task type, is handled by printing it, sending ``slack.error_alert`` and
    calling ``controlnet.cleanup()``. In that case ``None`` is returned.

    ``pipe`` is unused.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        set_configs_from_task(task)
        load_model_by_task(task)
        apply_style_args(data)
        lora_style.fetch_styles()
        avatar.fetch_from_network(task.get_model_id())

        # Task type -> handler dispatch table.
        handlers = {
            TaskType.TEXT_TO_IMAGE: text2img,
            TaskType.IMAGE_TO_IMAGE: img2img,
            TaskType.CANNY: canny,
            TaskType.POSE: pose,
            TaskType.TILE_UPSCALE: tile_upscale,
            TaskType.INPAINT: inpaint,
            TaskType.SCRIBBLE: scribble,
            TaskType.LINEARART: linearart,
        }
        handler = handlers.get(task.get_type())
        if handler is None:
            raise Exception("Invalid task type")
        return handler(task)
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        controlnet.cleanup()
        return None
|
|