import functools
import os
import traceback
from typing import List, Optional

import pydash as _
import torch
from botocore.vendored.six import BytesIO
from numpy import who  # NOTE(review): appears unused in this module

import internals.util.prompt as prompt_util
from internals.data.dataAccessor import update_db, update_db_source_failed
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.commons import Img2Img, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.pose_detector import PoseDetector
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.realtime_draw import RealtimeDraw
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.replace_background import ReplaceBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
from internals.pipelines.upscaler import Upscaler
from internals.util.args import apply_style_args
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
from internals.util.commons import (
    base64_to_image,
    construct_default_s3_url,
    download_image,
    image_to_base64,
    upload_image,
    upload_images,
)
from internals.util.config import (
    get_is_sdxl,
    get_low_gpu_mem,
    get_model_dir,
    get_num_return_sequences,
    set_configs_from_task,
    set_model_config,
    set_root_dir,
)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.model_loader import load_model_from_config
from internals.util.slack import Slack

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
auto_mode = False

# Shared pipeline / helper singletons, constructed once at import time and
# referenced by the task handlers below.
prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences())
upscaler = Upscaler()
pose_detector = PoseDetector()
inpainter = InPainter()
high_res = HighRes()
img2text = Image2Text()
img_classifier = ImageClassifier()
object_removal = ObjectRemoval()
replace_background = ReplaceBackground()
remove_background_v2 = RemoveBackgroundV2()
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
safety_checker = SafetyChecker()
slack = Slack()
avatar = Avatar()
realtime_draw = RealtimeDraw()
sdxl_tileupscaler = SDXLTileUpscaler()

# Instantiated custom-action scripts, reused across calls to custom_action().
custom_scripts: List = []


def get_patched_prompt(task: Task):
    """Return ``prompt_util.get_patched_prompt`` for ``task`` (a 2-tuple; callers use item 0)."""
    return prompt_util.get_patched_prompt(task, avatar, lora_style, prompt_modifier)


def get_patched_prompt_text2img(task: Task):
    """Return the text2img prompt parameters object built by ``prompt_util``."""
    return prompt_util.get_patched_prompt_text2img(
        task, avatar, lora_style, prompt_modifier
    )


def get_patched_prompt_tile_upscale(task: Task):
    """Return the prompt used for tile upscaling, derived with the image classifier and img2text."""
    return prompt_util.get_patched_prompt_tile_upscale(
        task, avatar, lora_style, img_classifier, img2text
    )


def get_intermediate_dimension(task: Task):
    """Return the ``(width, height)`` to generate at before any high-res pass.

    When the task requests the high-res fix, a smaller intermediate size is
    computed by ``HighRes``; otherwise the task's own size is used.
    """
    if task.get_high_res_fix():
        return HighRes.get_intermediate_dimension(task.get_width(), task.get_height())
    return task.get_width(), task.get_height()


def _apply_high_res(task: Task, prompt, images):
    """Run the high-res pass on ``images`` if the task asked for it.

    Args:
        task: the current task (supplies size, steps, negative prompt and
            extra high-res kwargs).
        prompt: prompt(s) forwarded to ``high_res.apply``.
        images: images produced by the first generation pass.

    Returns:
        The high-res images, or ``images`` unchanged when the fix is off.
    """
    if not task.get_high_res_fix():
        return images
    kwargs = {
        "prompt": prompt,
        "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
        "images": images,
        "width": task.get_width(),
        "height": task.get_height(),
        "num_inference_steps": task.get_steps(),
        **task.high_res_kwargs(),
    }
    images, _ = high_res.apply(**kwargs)
    return images


def _load_input_image(value: str):
    """Load an image from an ``http``-prefixed URL, or else decode it as base64."""
    if value.startswith("http"):
        return download_image(value)
    # consider image as base64
    return base64_to_image(value)


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images with the ControlNet ``"canny"`` model from the task's image.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("canny")

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()

    kwargs = {
        "prompt": prompt,
        "imageUrl": task.get_imageUrl(),
        "seed": task.get_seed(),
        "num_inference_steps": task.get_steps(),
        "width": width,
        "height": height,
        "negative_prompt": [
            f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
        ]
        * get_num_return_sequences(),
        **task.cnc_kwargs(),
        **lora_patcher.kwargs(),
    }
    images, has_nsfw = controlnet.process(**kwargs)
    images = _apply_high_res(task, prompt, images)

    generated_image_urls = upload_images(images, "_canny", task.get_taskId())

    lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task's image with a tile upscaler.

    On SDXL the ``sdxl_tileupscaler`` pipeline is used; otherwise the
    ControlNet ``"tile_upscaler"`` model is loaded and run.

    Returns:
        dict with ``modified_prompts``, ``generated_image_url`` (single URL of
        the first image) and ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    if get_is_sdxl():
        lora_patcher = lora_style.get_patcher(
            [sdxl_tileupscaler.pipe, high_res.pipe], task.get_style()
        )
        lora_patcher.patch()
        images, has_nsfw = sdxl_tileupscaler.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            resize_dimension=task.get_resize_dimension(),
            negative_prompt=task.get_negative_prompt(),
            width=task.get_width(),
            height=task.get_height(),
            model_id=task.get_model_id(),
        )
        lora_patcher.cleanup()
    else:
        controlnet.load_model("tile_upscaler")
        lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
        lora_patcher.patch()
        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "negative_prompt": task.get_negative_prompt(),
            "width": task.get_width(),
            "height": task.get_height(),
            "prompt": prompt,
            "resize_dimension": task.get_resize_dimension(),
            **task.cnt_kwargs(),
        }
        images, has_nsfw = controlnet.process(**kwargs)
        lora_patcher.cleanup()
    controlnet.cleanup()

    generated_image_url = upload_image(images[0], output_key)

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Generate images with the ControlNet ``"scribble"`` model.

    The input image is preprocessed with ``ControlNet.pidinet_image`` on SDXL
    and ``ControlNet.scribble_image`` otherwise.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("scribble")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()

    image = download_image(task.get_imageUrl()).resize((width, height))
    if get_is_sdxl():
        # We use sketch in SDXL
        image = ControlNet.pidinet_image(image)
    else:
        image = ControlNet.scribble_image(image)

    kwargs = {
        "image": [image] * get_num_return_sequences(),
        "seed": task.get_seed(),
        "num_inference_steps": task.get_steps(),
        "width": width,
        "height": height,
        "prompt": prompt,
        "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
        **task.cns_kwargs(),
    }
    images, has_nsfw = controlnet.process(**kwargs)
    images = _apply_high_res(task, prompt, images)

    generated_image_urls = upload_images(images, "_scribble", task.get_taskId())

    lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Generate images with the ControlNet ``"linearart"`` model from the task's image.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("linearart")

    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()

    kwargs = {
        "imageUrl": task.get_imageUrl(),
        "seed": task.get_seed(),
        "num_inference_steps": task.get_steps(),
        "width": width,
        "height": height,
        "prompt": prompt,
        "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
        **task.cnl_kwargs(),
    }
    images, has_nsfw = controlnet.process(**kwargs)
    images = _apply_high_res(task, prompt, images)

    generated_image_urls = upload_images(images, "_linearart", task.get_taskId())

    lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images with the ControlNet ``"pose"`` model.

    The pose image comes from one of three sources: the raw input image when
    pose estimation is off, ``pose_detector.transform`` when client pose
    coordinates are supplied, or ``controlnet.detect_pose`` otherwise. On
    non-SDXL models a depth map of the auxiliary image is passed alongside it.

    Args:
        task: the pose task.
        s3_outkey: suffix used when uploading the generated images.
        poses: NOTE(review): currently ignored - it is always recomputed
            below. Confirm whether callers expect to supply poses.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_model("pose")

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(
        [controlnet.pipe2, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()

    if not task.get_pose_estimation():
        print("Not detecting pose")
        pose_image = download_image(task.get_imageUrl()).resize(
            (task.get_width(), task.get_height())
        )
        poses = [pose_image] * get_num_return_sequences()
    elif task.get_pose_coordinates():
        infered_pose = pose_detector.transform(
            image=task.get_imageUrl(),
            client_coordinates=task.get_pose_coordinates(),
            width=task.get_width(),
            height=task.get_height(),
        )
        poses = [infered_pose] * get_num_return_sequences()
    else:
        poses = [
            controlnet.detect_pose(task.get_imageUrl())
        ] * get_num_return_sequences()

    if not get_is_sdxl():
        # in normal pipeline we use depth + pose controlnet
        depth = download_image(task.get_auxilary_imageUrl()).resize(
            (task.get_width(), task.get_height())
        )
        depth = ControlNet.depth_image(depth)
        images = [depth, poses[0]]
        upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))
        extra_kwargs = {
            "control_guidance_end": [0.5, 1.0],
        }
    else:
        images = poses[0]
        extra_kwargs = {}

    kwargs = {
        "prompt": prompt,
        "image": images,
        "seed": task.get_seed(),
        "num_inference_steps": task.get_steps(),
        "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
        "width": width,
        "height": height,
        **extra_kwargs,
        **task.cnp_kwargs(),
        **lora_patcher.kwargs(),
    }
    images, has_nsfw = controlnet.process(**kwargs)
    images = _apply_high_res(task, prompt, images)

    upload_image(poses[0], "crecoAI/{}_pose.png".format(task.get_taskId()))

    generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())

    lora_patcher.cleanup()
    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from the task's prompt with the text-to-image pipeline.

    Returns:
        dict containing every attribute of the patched prompt params object,
        plus ``generated_image_urls`` and ``has_nsfw``.
    """
    params = get_patched_prompt_text2img(task)

    width, height = get_intermediate_dimension(task)

    lora_patcher = lora_style.get_patcher(
        [text2img_pipe.pipe, high_res.pipe], task.get_style()
    )
    lora_patcher.patch()

    torch.manual_seed(task.get_seed())

    kwargs = {
        "params": params,
        "num_inference_steps": task.get_steps(),
        "height": height,
        "width": width,
        "negative_prompt": task.get_negative_prompt(),
        **task.t2i_kwargs(),
        **lora_patcher.kwargs(),
    }
    images, has_nsfw = text2img_pipe.process(**kwargs)

    # The high-res pass needs one prompt per image; fall back to empty prompts.
    high_res_prompt = params.prompt or [""] * get_num_return_sequences()
    images = _apply_high_res(task, high_res_prompt, images)

    generated_image_urls = upload_images(images, "", task.get_taskId())

    lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Generate images from the task's input image.

    On SDXL this runs the ControlNet ``"linearart"`` model; otherwise it uses
    the img2img pipeline.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    width, height = get_intermediate_dimension(task)

    torch.manual_seed(task.get_seed())

    if get_is_sdxl():
        # we run lineart for img2img
        controlnet.load_model("linearart")

        lora_patcher = lora_style.get_patcher(
            [controlnet.pipe2, high_res.pipe], task.get_style()
        )
        lora_patcher.patch()

        kwargs = {
            "imageUrl": task.get_imageUrl(),
            "seed": task.get_seed(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            "prompt": prompt,
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            **task.cnl_kwargs(),
            # Set after cnl_kwargs so it always overrides that value.
            "adapter_conditioning_scale": 0.3,
        }
        images, has_nsfw = controlnet.process(**kwargs)
    else:
        lora_patcher = lora_style.get_patcher(
            [img2img_pipe.pipe, high_res.pipe], task.get_style()
        )
        lora_patcher.patch()

        kwargs = {
            "prompt": prompt,
            "imageUrl": task.get_imageUrl(),
            "negative_prompt": [task.get_negative_prompt()]
            * get_num_return_sequences(),
            "num_inference_steps": task.get_steps(),
            "width": width,
            "height": height,
            **task.i2i_kwargs(),
            **lora_patcher.kwargs(),
        }
        images, has_nsfw = img2img_pipe.process(**kwargs)

    images = _apply_high_res(task, prompt, images)

    generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())

    lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint (or outpaint) the task's image within its mask.

    For OUTPAINT tasks the prompt is produced from the input image by
    ``img2text``; otherwise the patched user prompt is used.

    Returns:
        dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    if task.get_type() == TaskType.OUTPAINT:
        key = "_outpaint"
        # Previously referenced an undefined ``num_return_sequences`` name,
        # which raised NameError for every outpaint task.
        prompt = [img2text.process(task.get_imageUrl())] * get_num_return_sequences()
    else:
        key = "_inpaint"
        prompt, _ = get_patched_prompt(task)

    print({"prompts": prompt})

    kwargs = {
        "prompt": prompt,
        "image_url": task.get_imageUrl(),
        "mask_image_url": task.get_maskImageUrl(),
        "width": task.get_width(),
        "height": task.get_height(),
        "seed": task.get_seed(),
        "negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
        "num_inference_steps": task.get_steps(),
        **task.ip_kwargs(),
    }
    images = inpainter.process(**kwargs)
    generated_image_urls = upload_images(images, key, task.get_taskId())

    clear_cuda_and_gc()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Replace the background of the task's image using the task prompt.

    The prompt is expanded by ``prompt_modifier`` when prompt engineering is
    enabled; otherwise it is repeated once per returned sequence.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * get_num_return_sequences()

    lora_patcher = lora_style.get_patcher(replace_background.pipe, task.get_style())
    lora_patcher.patch()

    images, has_nsfw = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompt,
        negative_prompt=[task.get_negative_prompt()] * get_num_return_sequences(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        apply_high_res=task.get_high_res_fix(),
        conditioning_scale=task.rbg_controlnet_conditioning_scale(),
        model_type=task.get_modelType(),
    )

    generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())

    lora_patcher.cleanup()
    clear_cuda_and_gc()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from the task's image and upload the result.

    Returns:
        dict with ``generated_image_url``.
    """
    output_image = remove_background_v2.remove(
        task.get_imageUrl(), model_type=task.get_modelType()
    )

    output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    image_url = upload_image(output_image, output_key)

    return {"generated_image_url": image_url}


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image, using the anime upscaler for ANIME/COMIC models.

    Returns:
        dict with ``generated_image_url``.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    upscale_kwargs = {
        "image": task.get_imageUrl(),
        "width": task.get_width(),
        "height": task.get_height(),
        "face_enhance": task.get_face_enhance(),
        "resize_dimension": task.get_resize_dimension(),
    }
    if task.get_modelType() in (ModelType.ANIME, ModelType.COMIC):
        print("Using Anime model")
        out_img = upscaler.upscale_anime(**upscale_kwargs)
    else:
        print("Using Real model")
        out_img = upscaler.upscale(**upscale_kwargs)

    # The upscaler output is wrapped in BytesIO before upload.
    image_url = upload_image(BytesIO(out_img), output_key)

    clear_cuda_and_gc()

    return {"generated_image_url": image_url}


@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Remove the masked object from the task's image and upload the first result.

    Returns:
        dict whose ``generated_image_urls`` value is the single uploaded URL.
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    images = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    generated_image_urls = upload_image(images[0], output_key)

    clear_cuda()

    return {"generated_image_urls": generated_image_urls}


def rt_draw_seg(task: Task):
    """Realtime draw from a segmentation image (URL or base64).

    Returns:
        dict with the generated ``image`` encoded as base64.
    """
    image = _load_input_image(task.get_imageUrl())

    img = realtime_draw.process_seg(
        image=image,
        prompt=task.get_prompt(),
        negative_prompt=task.get_negative_prompt(),
        seed=task.get_seed(),
    )

    clear_cuda_and_gc()

    base64_image = image_to_base64(img)

    return {"image": base64_image}


def rt_draw_img(task: Task):
    """Realtime draw from an image plus an optional auxiliary image.

    Each input may be a URL or a base64 string; an empty value is passed
    through unchanged.

    Returns:
        dict with the generated ``image`` encoded as base64.
    """
    image = task.get_imageUrl()
    aux_image = task.get_auxilary_imageUrl()

    if image:
        image = _load_input_image(image)
    if aux_image:
        aux_image = _load_input_image(aux_image)

    img = realtime_draw.process_img(
        image=image,  # pyright: ignore
        image2=aux_image,  # pyright: ignore
        prompt=task.get_prompt(),
        negative_prompt=task.get_negative_prompt(),
        seed=task.get_seed(),
    )

    clear_cuda_and_gc()

    base64_image = image_to_base64(img)

    return {"image": base64_image}


def custom_action(task: Task):
    """Run the external script whose ``__name__`` matches the task's action name.

    Script instances are cached in ``custom_scripts`` by ``__name__``; a
    fresh instance is still constructed on every iteration before the cache
    lookup. Returns the script's result, or ``None`` if no script matches.
    """
    from external.scripts import __scripts__

    kwargs = {
        "CONTROLNET": controlnet,
        "LORASTYLE": lora_style,
    }

    torch.manual_seed(task.get_seed())

    for script in __scripts__:
        script = script.Script(**kwargs)
        existing_script = _.find(
            custom_scripts, lambda x: x.__name__ == script.__name__
        )
        if existing_script:
            script = existing_script
        else:
            custom_scripts.append(script)

        data = task.get_action_data()
        if data["name"] == script.__name__:
            return script(task, data)


def load_model_by_task(task_type: TaskType, model_id=-1):
    """Ensure the base pipelines and the models needed for ``task_type`` are loaded.

    The base text2img pipeline (and the pipelines created from it) is only
    loaded the first time, when ``text2img_pipe.is_loaded()`` is false.
    """
    if not text2img_pipe.is_loaded():
        text2img_pipe.load(get_model_dir())
        img2img_pipe.create(text2img_pipe)
        high_res.load(img2img_pipe)

        inpainter.init(text2img_pipe)
        controlnet.init(text2img_pipe)

    if task_type in (TaskType.INPAINT, TaskType.OUTPAINT):
        inpainter.load()
        safety_checker.apply(inpainter)
    elif task_type == TaskType.REPLACE_BG:
        replace_background.load(
            upscaler=upscaler, base=text2img_pipe, high_res=high_res
        )
    elif task_type in (TaskType.RT_DRAW_SEG, TaskType.RT_DRAW_IMG):
        realtime_draw.load(text2img_pipe)
    elif task_type == TaskType.OBJECT_REMOVAL:
        object_removal.load(get_model_dir())
    elif task_type == TaskType.UPSCALE_IMAGE:
        upscaler.load()
    elif task_type == TaskType.TILE_UPSCALE:
        if get_is_sdxl():
            sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
        else:
            controlnet.load_model("tile_upscaler")
    elif task_type == TaskType.CANNY:
        controlnet.load_model("canny")
    elif task_type == TaskType.SCRIBBLE:
        controlnet.load_model("scribble")
    elif task_type == TaskType.LINEARART:
        controlnet.load_model("linearart")
    elif task_type == TaskType.POSE:
        controlnet.load_model("pose")


def unload_model_by_task(task_type: TaskType):
    """Unload the task-specific models loaded by ``load_model_by_task``."""
    if task_type in (TaskType.INPAINT, TaskType.OUTPAINT):
        inpainter.unload()
    elif task_type == TaskType.REPLACE_BG:
        replace_background.unload()
    elif task_type == TaskType.OBJECT_REMOVAL:
        object_removal.unload()
    elif task_type == TaskType.TILE_UPSCALE:
        if get_is_sdxl():
            sdxl_tileupscaler.unload()
        else:
            controlnet.unload()
    elif task_type in (
        TaskType.CANNY,
        TaskType.SCRIBBLE,
        TaskType.LINEARART,
        TaskType.POSE,
    ):
        controlnet.unload()


def apply_safety_checkers():
    """Attach the safety checker to the text2img, img2img and ControlNet pipelines."""
    safety_checker.apply(text2img_pipe)
    safety_checker.apply(img2img_pipe)
    safety_checker.apply(controlnet)


def model_fn(model_dir):
    """Load config, avatars, LoRA styles and the base text-to-image models from ``model_dir``."""
    print("Logs: model loaded .... starts")

    config = load_model_from_config(model_dir)
    set_model_config(config)
    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)
    lora_style.load(model_dir)
    load_model_by_task(TaskType.TEXT_TO_IMAGE)

    print("Logs: model loaded ....")
    return


def auto_unload_task(func):
    """Decorate ``func`` so that, in low-GPU-memory mode, the models for the
    task built from its first positional argument are unloaded after it returns.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if get_low_gpu_mem():
            task = Task(args[0])
            unload_model_by_task(task.get_type())  # pyright: ignore
        return result

    return wrapper


@auto_unload_task
@FailureHandler.clear
def predict_fn(data, pipe):
    """Build a Task from ``data``, load its models and dispatch to its handler.

    Args:
        data: raw task payload passed to ``Task``.
        pipe: not used by this function.

    Returns:
        The handler's result. Returns ``None`` for SYSTEM_CMD and
        PRELOAD_MODEL tasks, and when any step raises. In that case the
        error goes to ``slack.error_alert`` and ``update_db_source_failed``.

    Raises:
        Nothing from within the ``try`` block: an unknown task type raises
        ``Exception("Invalid task type")``, which is caught by the handler
        below.
    """
    task = Task(data)
    print("task is ", data)

    clear_cuda_and_gc()

    FailureHandler.handle(task)

    try:
        task_type = task.get_type()

        # Set set_environment
        set_configs_from_task(task)

        # Load model based on task
        load_model_by_task(
            task.get_type() or TaskType.TEXT_TO_IMAGE, task.get_model_id()
        )

        # Apply safety checkers
        apply_safety_checkers()

        # Realtime generation apis
        if task_type == TaskType.RT_DRAW_SEG:
            return rt_draw_seg(task)
        if task_type == TaskType.RT_DRAW_IMG:
            return rt_draw_img(task)

        # Apply arguments
        apply_style_args(data)

        # Re-fetch styles
        lora_style.fetch_styles()

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        if task_type == TaskType.TEXT_TO_IMAGE:
            return text2img(task)
        elif task_type == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        elif task_type == TaskType.CANNY:
            return canny(task)
        elif task_type == TaskType.POSE:
            return pose(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        elif task_type in (TaskType.INPAINT, TaskType.OUTPAINT):
            return inpaint(task)
        elif task_type == TaskType.SCRIBBLE:
            return scribble(task)
        elif task_type == TaskType.LINEARART:
            return linearart(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        elif task_type == TaskType.CUSTOM_ACTION:
            return custom_action(task)
        elif task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.SYSTEM_CMD:
            # SECURITY NOTE(review): this runs the task prompt as a shell
            # command, so anyone who can submit a task can execute arbitrary
            # commands on this host. Confirm that SYSTEM_CMD is restricted to
            # trusted callers.
            os.system(task.get_prompt())
        elif task_type == TaskType.PRELOAD_MODEL:
            try:
                task_type = TaskType(task.get_prompt())
            except ValueError:
                # Unknown model name: fall back to loading only the base pipelines.
                task_type = TaskType.SYSTEM_CMD
            load_model_by_task(task_type)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        # NOTE(review): LoRA patches applied by a handler are not cleaned up
        # here if that handler raised mid-run. Confirm that lora_style resets
        # them on the next patch.
        slack.error_alert(task, e)
        controlnet.cleanup()
        traceback.print_exc()
        update_db_source_failed(task.get_sourceId(), task.get_userId())
        return None