"""Inference handler: loads the image pipelines and dispatches tasks to them.

``model_fn`` loads the pipelines once; ``predict_fn`` builds a ``Task`` from
the request payload and routes it to the matching handler defined below.
(The ``model_fn``/``predict_fn`` names presumably follow a SageMaker-style
inference-hook contract — confirm against the deployment.)
"""
import os
from io import BytesIO

import torch

import internals.util.prompt as prompt_util
from internals.data.dataAccessor import update_db, update_db_source_failed
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
from internals.pipelines.replace_background import ReplaceBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.upscaler import Upscaler
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
from internals.util.commons import construct_default_s3_url, upload_image, upload_images
from internals.util.config import (
    num_return_sequences,
    set_configs_from_task,
    set_model_config,
    set_root_dir,
)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.model_loader import load_model_from_config
from internals.util.slack import Slack

# Let cuDNN benchmark and cache the fastest convolution algorithms, and allow
# TF32 tensor cores for float32 matmuls (faster on supporting GPUs at slightly
# reduced precision).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# Not referenced in this module; kept in case other modules import it.
auto_mode = False

# Module-level singletons shared by every handler. Several of them have their
# `load`/`init` called later by model_fn(); construction order is unchanged.
slack = Slack()

prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
controlnet = ControlNet()
safety_checker = SafetyChecker()
high_res = HighRes()
object_removal = ObjectRemoval()
remove_background_v2 = RemoveBackgroundV2()
replace_background = ReplaceBackground()
img2text = Image2Text()
img_classifier = ImageClassifier()
avatar = Avatar()
lora_style = LoraStyle()


def get_patched_prompt_tile_upscale(task: Task):
    """Build the prompt for a tile-upscale task.

    Delegates to ``prompt_util.get_patched_prompt_tile_upscale`` with the
    module-level avatar, LoRA-style, image-classifier and image-to-text helpers.
    """
    return prompt_util.get_patched_prompt_tile_upscale(
        task, avatar, lora_style, img_classifier, img2text
    )


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task image with the ControlNet ``tile_upscaler`` model.

    The task's LoRA style is patched into ``controlnet.pipe`` for the duration
    of the call and removed afterwards, on success or failure.

    Returns:
        dict with ``modified_prompts`` (the patched prompt),
        ``generated_image_url`` (value returned by ``upload_image``) and
        ``has_nsfw`` (flag returned by ``controlnet.process``).
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_model("tile_upscaler")

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()

    try:
        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "negative_prompt": task.get_negative_prompt(),
            "width": task.get_width(),
            "height": task.get_height(),
            "prompt": prompt,
            "resize_dimension": task.get_resize_dimension(),
            **task.cnt_kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)
        generated_image_url = upload_image(images[0], output_key)
    finally:
        # Undo the LoRA patch and release the pipeline even if argument
        # building, generation or upload raises; otherwise the patch could
        # remain applied to controlnet.pipe for the next request.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background of the task image and upload the result.

    Returns:
        dict with ``generated_image_url`` built by ``construct_default_s3_url``
        from the upload key.
    """
    output_image = remove_background_v2.remove(
        task.get_imageUrl(), model_type=task.get_modelType()
    )

    output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    upload_image(output_image, output_key)

    return {"generated_image_url": construct_default_s3_url(output_key)}


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task image.

    When prompt engineering is enabled, the prompt (after avatar code-name
    substitution) is expanded by ``prompt_modifier``. Otherwise it is repeated
    ``num_return_sequences`` times. The negative prompt is always repeated
    ``num_return_sequences`` times.

    Returns:
        dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt = avatar.add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences
    print({"prompts": prompt})

    kwargs = {
        "prompt": prompt,
        "image_url": task.get_imageUrl(),
        "mask_image_url": task.get_maskImageUrl(),
        "width": task.get_width(),
        "height": task.get_height(),
        "seed": task.get_seed(),
        "negative_prompt": [task.get_negative_prompt()] * num_return_sequences,
        "num_inference_steps": task.get_steps(),
        **task.ip_kwargs(),
    }
    images = inpainter.process(**kwargs)
    generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())

    clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Erase the masked object from the task image and upload the first result.

    Returns:
        dict with ``generated_image_urls``. NOTE(review): despite the plural
        key, this holds the single value returned by ``upload_image`` for
        ``images[0]``; kept as-is because callers read this key.
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    images = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    generated_image_urls = upload_image(images[0], output_key)

    clear_cuda()

    return {"generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Generate a new background around the task image's subject.

    Prompt handling mirrors ``inpaint`` (engineered via ``prompt_modifier`` or
    repeated ``num_return_sequences`` times), but without avatar code names.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``.
    """
    prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    images, has_nsfw = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompt,
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        extend_object=task.rbg_extend_object(),
        product_scale_width=task.get_image_scale(),
        conditioning_scale=task.rbg_controlnet_conditioning_scale(),
        model_type=task.get_modelType(),
    )

    generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task image, using the anime upscaler for ANIME/COMIC models.

    The upscaler result is wrapped in ``BytesIO`` before upload, so it is
    expected to be raw bytes.

    Returns:
        dict with ``generated_image_url`` built by ``construct_default_s3_url``.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    if task.get_modelType() in (ModelType.ANIME, ModelType.COMIC):
        print("Using Anime model")
        out_img = upscaler.upscale_anime(
            image=task.get_imageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            face_enhance=task.get_face_enhance(),
            resize_dimension=task.get_resize_dimension(),
        )
    else:
        print("Using Real model")
        out_img = upscaler.upscale(
            image=task.get_imageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            face_enhance=task.get_face_enhance(),
            resize_dimension=task.get_resize_dimension(),
        )

    upload_image(BytesIO(out_img), output_key)
    return {"generated_image_url": construct_default_s3_url(output_key)}


def model_fn(model_dir):
    """Load configuration and every pipeline from ``model_dir``.

    ``controlnet.init`` and ``replace_background.load`` receive other
    pipeline objects (``high_res``, ``upscaler``, ``remove_background_v2``),
    so they run after those are loaded.

    Returns:
        None.
    """
    print("Logs: model loaded .... starts")

    config = load_model_from_config(model_dir)
    set_model_config(config)
    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)
    lora_style.load(model_dir)

    prompt_modifier.load()
    safety_checker.load()

    object_removal.load(model_dir)
    upscaler.load()
    inpainter.load()
    high_res.load()

    controlnet.init(high_res)

    replace_background.load(
        upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
    )

    print("Logs: model loaded ....")
    return


def load_model_by_task(task: Task):
    """Pre-load task-specific models; only TILE_UPSCALE needs one here."""
    if task.get_type() == TaskType.TILE_UPSCALE:
        controlnet.load_model("tile_upscaler")
        safety_checker.apply(controlnet)


@FailureHandler.clear
def predict_fn(data, pipe):
    """Build a ``Task`` from ``data`` and dispatch it to the matching handler.

    Args:
        data: request payload passed to ``Task``.
        pipe: unused here; kept for the hook signature.

    Returns:
        The handler's result dict, or None for SYSTEM_CMD tasks and on any
        error. On error, an alert is sent to Slack, ``controlnet`` is cleaned
        up and the task's source is marked failed.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Apply per-task configuration (environment etc.).
        set_configs_from_task(task)

        # Load model based on task
        load_model_by_task(task)

        # Apply safety checker based on environment
        safety_checker.apply(inpainter)
        safety_checker.apply(replace_background)
        safety_checker.apply(high_res)

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        elif task_type == TaskType.SYSTEM_CMD:
            # SECURITY(review): runs the task's prompt text as a shell command.
            # Anyone able to submit a SYSTEM_CMD task gets arbitrary command
            # execution on this host — confirm this task type is restricted to
            # trusted/internal callers, or remove it.
            os.system(task.get_prompt())
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        controlnet.cleanup()
        update_db_source_failed(task.get_sourceId(), task.get_userId())
    return None