|
from io import BytesIO |
|
|
|
import torch |
|
|
|
from internals.data.dataAccessor import update_db |
|
from internals.data.task import ModelType, Task, TaskType |
|
from internals.pipelines.inpainter import InPainter |
|
from internals.pipelines.object_remove import ObjectRemoval |
|
from internals.pipelines.prompt_modifier import PromptModifier |
|
from internals.pipelines.remove_background import (RemoveBackground, |
|
RemoveBackgroundV2) |
|
from internals.pipelines.replace_background import ReplaceBackground |
|
from internals.pipelines.safety_checker import SafetyChecker |
|
from internals.pipelines.upscaler import Upscaler |
|
from internals.util.avatar import Avatar |
|
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda |
|
from internals.util.commons import (construct_default_s3_url, upload_image, |
|
upload_images) |
|
from internals.util.config import (num_return_sequences, set_configs_from_task, |
|
set_root_dir) |
|
from internals.util.failure_hander import FailureHandler |
|
from internals.util.slack import Slack |
|
|
|
# Global PyTorch performance flags:
# - cudnn.benchmark lets cuDNN time several convolution algorithms for each new
#   input shape and cache the fastest one.
# - allow_tf32 lets CUDA matmuls use TF32 tensor cores on Ampere+ GPUs
#   (faster, slightly reduced precision).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not read anywhere in this file — possibly dead or read by
# another module; confirm before removing.
auto_mode = False

# Shared Slack client: its auto_send_alert decorator wraps every task handler
# below, and predict_fn calls error_alert on failures.
slack = Slack()

# Pipeline singletons, created once at import time. Most are loaded later in
# model_fn via .load(); the order here is kept as-is in case constructors
# have side effects.
# NOTE(review): num_return_sequences is bound by value at import time; if
# set_configs_from_task() changes it in internals.util.config, this module
# (and prompt_modifier) will not see the new value — confirm intended.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
safety_checker = SafetyChecker()
object_removal = ObjectRemoval()
# NOTE(review): model_fn never calls remove_background_v2.load(); presumably it
# loads in its constructor or lazily — verify.
remove_background_v2 = RemoveBackgroundV2()
avatar = Avatar()
replace_background = ReplaceBackground()
|
|
|
|
|
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Strip the background from the task image and upload the result.

    The image at ``task.get_imageUrl()`` goes through
    ``remove_background_v2.remove``. The output is uploaded under
    ``crecoAI/<taskId>_rmbg.png``.

    Returns:
        dict with ``generated_image_url``: the default S3 URL for the key.
    """
    cutout = remove_background_v2.remove(task.get_imageUrl())

    key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, key)

    url = construct_default_s3_url(key)
    return {"generated_image_url": url}
|
|
|
|
|
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task image and upload the results.

    The prompt first goes through ``avatar.add_code_names``. When prompt
    engineering is enabled, it is then expanded by ``prompt_modifier.modify``.
    Otherwise it is repeated ``num_return_sequences`` times. The negative
    prompt is always repeated ``num_return_sequences`` times.

    ``clear_cuda()`` runs in a ``finally`` block. It therefore runs whether
    inference and upload succeed or raise. Previously a failure skipped it.

    Returns:
        dict with ``modified_prompts`` (the prompt list sent to the pipeline)
        and ``generated_image_urls`` (whatever ``upload_images`` returns).
    """
    prompt = avatar.add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    print({"prompts": prompt})

    try:
        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # Run cleanup on the failure path too; the exception still propagates
        # to the decorators. Presumably clear_cuda releases cached GPU memory
        # (see internals.util.cache).
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
|
|
|
|
|
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Erase the masked object from the task image and upload the first result.

    The first image returned by ``object_removal.process`` is uploaded under
    ``crecoAI/<taskId>_object_remove.png``. ``clear_cuda()`` runs in a
    ``finally`` block, so it also runs when processing or upload raises.

    Returns:
        dict with ``generated_image_urls``: the value returned by
        ``upload_image``.
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    try:
        images = object_removal.process(
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            seed=task.get_seed(),
            width=task.get_width(),
            height=task.get_height(),
        )
        # NOTE(review): remove_bg and upscale_image ignore upload_image()'s
        # return value and return construct_default_s3_url(key) instead.
        # Here its return value is passed straight through. Confirm that
        # upload_image returns the URL(s) that callers expect.
        generated_image_urls = upload_image(images[0], output_key)
    finally:
        # Cleanup runs even when inference or upload fails; the exception
        # still propagates to the decorators.
        clear_cuda()

    return {"generated_image_urls": generated_image_urls}
|
|
|
|
|
@update_db
@slack.auto_send_alert
def replace_bg(task: Task):
    """Generate new backgrounds behind the task's product image.

    With prompt engineering enabled, the raw prompt is expanded by
    ``prompt_modifier.modify``. Otherwise it is repeated
    ``num_return_sequences`` times. The negative prompt is always repeated
    ``num_return_sequences`` times.

    Returns:
        dict with ``modified_prompts`` (prompts sent to the pipeline),
        ``generated_image_urls`` (from ``upload_images``) and ``has_nsfw``
        (the second value returned by ``replace_background.replace``).
    """
    raw_prompt = task.get_prompt()
    prompts = (
        prompt_modifier.modify(raw_prompt)
        if task.is_prompt_engineering()
        else [raw_prompt] * num_return_sequences
    )

    result = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompts,
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        resize_dimension=task.get_resize_dimension(),
        product_scale_width=task.get_image_scale(),
    )
    images, has_nsfw = result

    urls = upload_images(images, "_replace_bg", task.get_taskId())

    response = {"modified_prompts": prompts, "generated_image_urls": urls}
    response["has_nsfw"] = has_nsfw
    return response
|
|
|
|
|
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task image and upload it under ``crecoAI/<taskId>_upscale.png``.

    ``upscaler.upscale_anime`` is used when the task's model type is
    ``ModelType.ANIME``; ``upscaler.upscale`` is used otherwise. Both receive
    the same keyword arguments. The result is wrapped in ``BytesIO`` before
    upload.

    Returns:
        dict with ``generated_image_url``: the default S3 URL for the key.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    anime = task.get_modelType() == ModelType.ANIME
    print("Using Anime model" if anime else "Using Real model")
    run_upscale = upscaler.upscale_anime if anime else upscaler.upscale

    upscaled = run_upscale(
        image=task.get_imageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        face_enhance=task.get_face_enhance(),
        resize_dimension=task.get_resize_dimension(),
    )

    upload_image(BytesIO(upscaled), output_key)
    return {"generated_image_url": construct_default_s3_url(output_key)}
|
|
|
|
|
def model_fn(model_dir):
    """Load every pipeline once, before any request is handled.

    NOTE(review): the ``model_fn``/``predict_fn`` pair presumably follows the
    SageMaker inference-handler convention — confirm against the deployment.

    Args:
        model_dir: Local directory passed to ``avatar.load_local`` and
            ``object_removal.load``. The other pipelines' ``load()`` calls here
            take no directory argument (or take other pipelines instead).

    Returns:
        None. The loaded pipelines are module-level globals, so nothing needs
        to be handed to ``predict_fn`` through its ``pipe`` argument.
    """
    print("Logs: model loaded .... starts")

    # Give internals.util.config the path of this file as the root directory.
    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)

    prompt_modifier.load()
    safety_checker.load()

    object_removal.load(model_dir)
    upscaler.load()
    inpainter.load()

    # Kept after upscaler.load(): replace_background is given the upscaler and
    # remove_background_v2 instances, presumably to reuse them — keep this
    # ordering.
    replace_background.load(upscaler, remove_background_v2)

    print("Logs: model loaded ....")
    return
|
|
|
|
|
@FailureHandler.clear
def predict_fn(data, pipe):
    """Handle one inference request and return the task handler's result.

    Args:
        data: Raw request payload; it is wrapped in ``Task``.
        pipe: Presumably the value returned by ``model_fn`` (``None``). It is
            unused because the pipelines are module-level globals.

    Returns:
        The dict returned by the handler that matches ``task.get_type()``, or
        ``None`` if any step inside the ``try`` raised. In that case the error
        and its traceback are printed and the error is reported through
        ``slack.error_alert``.

    Exceptions raised by ``Task(data)`` or ``FailureHandler.handle(task)`` are
    not caught here, because they run before the ``try``.
    """
    import traceback

    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        set_configs_from_task(task)

        safety_checker.apply(inpainter)

        avatar.fetch_from_network(task.get_model_id())

        # Plain == comparisons (not a dict lookup) so matching semantics stay
        # the same whatever type get_type() returns.
        task_type = task.get_type()
        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        else:
            # Include the offending value so the alert is actionable.
            raise Exception(f"Invalid task type: {task_type!r}")
    except Exception as e:
        print(f"Error: {e}")
        # str(e) alone loses where the failure happened; print the full stack.
        traceback.print_exc()
        slack.error_alert(task, e)
        return None
|
|