from contextlib import contextmanager
from typing import List, Optional

import torch

from internals.data.dataAccessor import update_db
from internals.data.task import Task, TaskType
from internals.pipelines.commons import Img2Img, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.pose_detector import PoseDetector
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.safety_checker import SafetyChecker
from internals.util.args import apply_style_args
from internals.util.avatar import Avatar
from internals.util.cache import (auto_clear_cuda_and_gc, clear_cuda,
                                  clear_cuda_and_gc)
from internals.util.commons import pickPoses, upload_image, upload_images
from internals.util.config import (num_return_sequences, set_configs_from_task,
                                   set_root_dir)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.slack import Slack

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

auto_mode = False

# Module-level singletons shared by every task handler below. They are only
# constructed here; the actual loading happens in ``model_fn``.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
pose_detector = PoseDetector()
inpainter = InPainter()
img2text = Image2Text()
img_classifier = ImageClassifier()
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
safety_checker = SafetyChecker()
slack = Slack()
avatar = Avatar()


@contextmanager
def _patched_lora(pipe, style):
    """Apply the LoRA patch for *style* to *pipe* and always undo it on exit.

    Yields the patcher so callers can forward ``patcher.kwargs()`` to the
    pipeline. ``cleanup()`` runs in a ``finally`` so that a failure while
    generating or uploading does not leave the patch applied to a pipeline
    object that is reused by later requests.
    """
    patcher = lora_style.get_patcher(pipe, style)
    patcher.patch()
    try:
        yield patcher
    finally:
        patcher.cleanup()


def get_patched_prompt(task: Task):
    """Build the prompt lists used by the img2img / controlnet / inpaint handlers.

    Returns:
        A ``(prompt, ori_prompt)`` tuple of lists. ``prompt`` comes from
        ``prompt_modifier.modify`` when the task requests prompt engineering,
        otherwise it is the task prompt repeated ``num_return_sequences``
        times. ``ori_prompt`` is always the task prompt repeated
        ``num_return_sequences`` times. Every entry of both lists has avatar
        code names and the LoRA style applied.
    """

    def add_style_and_character(prompts: List[str]) -> List[str]:
        # Build a new list rather than mutating in place, so the list handed
        # back by ``prompt_modifier.modify`` is never altered.
        style = task.get_style()
        return [
            lora_style.prepend_style_to_prompt(avatar.add_code_names(p), style)
            for p in prompts
        ]

    prompt = task.get_prompt()

    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    ori_prompt = add_style_and_character([task.get_prompt()] * num_return_sequences)
    prompt = add_style_and_character(prompt)

    print({"prompts": prompt})

    return (prompt, ori_prompt)


def get_patched_prompt_text2img(task: Task) -> Text2Img.Params:
    """Build the ``Text2Img.Params`` for a text-to-image task.

    When the task has both a left and a right prompt, the params carry three
    parallel lists (``prompt``, ``prompt_left``, ``prompt_right``); otherwise
    they carry ``prompt`` and ``modified_prompt``. All prompts get avatar code
    names and the LoRA style applied.
    """

    def add_style_and_character(prompt: str, prepend: str = "") -> str:
        prompt = avatar.add_code_names(prompt)
        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
        return prepend + prompt

    base_prompt = task.get_prompt()

    # Both branches below need the (optionally engineered) prompt variants.
    if task.is_prompt_engineering():
        mod_prompt = prompt_modifier.modify(base_prompt)
    else:
        mod_prompt = [base_prompt] * num_return_sequences

    if task.get_prompt_left() and task.get_prompt_right():
        # prepend = "2characters, "
        prepend = ""

        # The styled prompts do not depend on the loop variable: compute once.
        styled = add_style_and_character(base_prompt, prepend)
        styled_left = add_style_and_character(task.get_prompt_left(), prepend)
        styled_right = add_style_and_character(task.get_prompt_right(), prepend)

        prompt, prompt_left, prompt_right = [], [], []
        for modified in mod_prompt:
            # Whatever remains after removing the base prompt from the
            # modified one is appended to each of the three prompts.
            suffix = modified.replace(base_prompt, "")
            prompt.append(styled + suffix)
            prompt_left.append(styled_left + suffix)
            prompt_right.append(styled_right + suffix)

        params = Text2Img.Params(
            prompt=prompt,
            prompt_left=prompt_left,
            prompt_right=prompt_right,
        )
    else:
        params = Text2Img.Params(
            prompt=[add_style_and_character(base_prompt)] * num_return_sequences,
            modified_prompt=[add_style_and_character(mp) for mp in mod_prompt],
        )

    print(params)

    return params


def get_patched_prompt_tile_upscale(task: Task):
    """Build the single prompt string used by the tile upscaler.

    Uses the task prompt when present, otherwise the output of
    ``img2text.process`` on the task image. When the task has no style, the
    result of ``img_classifier.classify`` is prefixed to the prompt.
    """
    if task.get_prompt():
        prompt = task.get_prompt()
    else:
        prompt = img2text.process(task.get_imageUrl())

    prompt = avatar.add_code_names(prompt)
    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())

    if not task.get_style():
        class_name = img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )
    else:
        class_name = ""
    prompt = class_name + " " + prompt
    # Drops the leading space left behind when class_name is empty.
    prompt = prompt.strip()

    print({"prompt": prompt})

    return prompt


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Run the canny controlnet for *task* and upload the generated images.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()

    # pipe2 is used for canny and pose
    with _patched_lora(controlnet.pipe2, task.get_style()) as lora_patcher:
        images, has_nsfw = controlnet.process_canny(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_cy_guidance_scale(),
            negative_prompt=[
                f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
            ]
            * num_return_sequences,
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_canny", task.get_taskId())

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Run the tile upscaler controlnet for *task* and upload the first image.

    Returns a dict with ``modified_prompts``, ``generated_image_url``
    (singular) and ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_tile_upscaler()

    with _patched_lora(controlnet.pipe, task.get_style()):
        images, has_nsfw = controlnet.process_tile_upscaler(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            resize_dimension=task.get_resize_dimension(),
            negative_prompt=task.get_negative_prompt(),
            guidance_scale=task.get_ti_guidance_scale(),
        )

        generated_image_url = upload_image(images[0], output_key)

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Run the scribble controlnet for *task* and upload the generated images.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_scribble()

    with _patched_lora(controlnet.pipe2, task.get_style()):
        images, has_nsfw = controlnet.process_scribble(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_scribble", task.get_taskId())

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Run the linearart controlnet for *task* and upload the generated images.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_linearart()

    with _patched_lora(controlnet.pipe2, task.get_style()):
        images, has_nsfw = controlnet.process_linearart(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_linearart", task.get_taskId())

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Run the pose controlnet for *task* and upload the generated images.

    Args:
        task: The task being processed.
        s3_outkey: Suffix passed to ``upload_images`` for the generated images.
        poses: Optional conditioning poses supplied by the caller. When
            given (and non-empty) they are used as-is. Otherwise the pose is
            derived from the task's pose coordinates if present, or detected
            from the task image, and repeated ``num_return_sequences`` times.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``. The first pose is also uploaded under a
    ``crecoAI/<taskId>_pose.png`` key.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()

    # pipe2 is used for canny and pose
    with _patched_lora(controlnet.pipe2, task.get_style()) as lora_patcher:
        # Only derive poses when the caller did not supply any; previously a
        # caller-supplied list was unconditionally overwritten here.
        if not poses:
            if task.get_pose_coordinates():
                infered_pose = pose_detector.transform(
                    image=task.get_imageUrl(),
                    client_coordinates=task.get_pose_coordinates(),
                    width=task.get_width(),
                    height=task.get_height(),
                )
                poses = [infered_pose] * num_return_sequences
            else:
                poses = [
                    controlnet.detect_pose(task.get_imageUrl())
                ] * num_return_sequences

        images, has_nsfw = controlnet.process_pose(
            prompt=prompt,
            image=poses,
            seed=task.get_seed(),
            steps=task.get_steps(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_po_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        pose_output_key = "crecoAI/{}_pose.png".format(task.get_taskId())
        upload_image(poses[0], pose_output_key)

        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Run the text-to-image pipeline for *task* and upload the images.

    Returns a dict containing every attribute of the prompt params plus
    ``generated_image_urls`` and ``has_nsfw``.
    """
    params = get_patched_prompt_text2img(task)

    with _patched_lora(text2img_pipe.pipe, task.get_style()) as lora_patcher:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = text2img_pipe.process(
            params=params,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=task.get_negative_prompt(),
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "", task.get_taskId())

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Run the image-to-image pipeline for *task* and upload the images.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    with _patched_lora(img2img_pipe.pipe, task.get_style()) as lora_patcher:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            strength=task.get_i2i_strength(),
            guidance_scale=task.get_i2i_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Run the inpainter for *task* and upload the generated images.

    Returns a dict with ``modified_prompts`` and ``generated_image_urls``
    (no ``has_nsfw`` key, unlike the other handlers).
    """
    prompt, _ = get_patched_prompt(task)

    print({"prompts": prompt})

    images = inpainter.process(
        prompt=prompt,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
    )

    generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())

    clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


def model_fn(model_dir):
    """Load every module-level component, using *model_dir* where required.

    The img2img pipeline and the inpainter are created from the loaded
    text2img pipeline, and the safety checker is applied to the text2img,
    img2img and controlnet objects. Returns ``None``.
    """
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)

    prompt_modifier.load()
    pose_detector.load()
    img2text.load()
    img_classifier.load()

    lora_style.load(model_dir)
    safety_checker.load()

    controlnet.load(model_dir)
    text2img_pipe.load(model_dir)
    img2img_pipe.create(text2img_pipe)
    inpainter.create(text2img_pipe)

    safety_checker.apply(text2img_pipe)
    safety_checker.apply(img2img_pipe)
    safety_checker.apply(controlnet)

    print("Logs: model loaded ....")
    return


@FailureHandler.clear
def predict_fn(data, pipe):
    """Dispatch one request to the handler matching its task type.

    Args:
        data: Raw request payload; wrapped in a ``Task`` and also passed to
            ``apply_style_args``.
        pipe: Unused by this function.

    Returns:
        The handler's result, or ``None`` when any exception is raised (the
        error is printed, reported through ``slack.error_alert`` and the
        controlnet is cleaned up).
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Set set_environment
        set_configs_from_task(task)

        # Apply arguments
        apply_style_args(data)

        # Re-fetch styles
        lora_style.fetch_styles()

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.TEXT_TO_IMAGE:
            # character sheet
            if "character sheet" in task.get_prompt().lower():
                return pose(task, s3_outkey="", poses=pickPoses())
            else:
                return text2img(task)
        elif task_type == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        elif task_type == TaskType.CANNY:
            return canny(task)
        elif task_type == TaskType.POSE:
            return pose(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.SCRIBBLE:
            return scribble(task)
        elif task_type == TaskType.LINEARART:
            return linearart(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        controlnet.cleanup()
        return None