File size: 4,822 Bytes
19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
from io import BytesIO
import torch
from internals.data.dataAccessor import update_db
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.remove_background import RemoveBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.upscaler import Upscaler
from internals.util.avatar import Avatar
from internals.util.cache import clear_cuda
from internals.util.commons import (construct_default_s3_url, upload_image,
upload_images)
from internals.util.config import set_configs_from_task, set_root_dir
from internals.util.failure_hander import FailureHandler
from internals.util.slack import Slack
# Global PyTorch performance flags: let cuDNN benchmark and cache the fastest
# convolution algorithms, and allow TF32 math for CUDA matmuls (TF32 trades
# some float32 precision for speed on GPUs that support it).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# Number of prompts (and matching negative prompts) sent to the inpainter per
# request; also passed to PromptModifier as its num_of_sequences.
num_return_sequences = 4
# NOTE(review): not referenced anywhere in this file's visible code — confirm
# whether another module reads it before removing.
auto_mode = False

# Process-wide singletons shared by every request, constructed at import time.
# Their load()/load_local() methods are called later in model_fn.
# `slack` must be defined before the handler functions below, which use
# `slack.auto_send_alert` as a decorator at definition time.
slack = Slack()
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
safety_checker = SafetyChecker()
object_removal = ObjectRemoval()
avatar = Avatar()
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from the task's input image and upload the result.

    A new ``RemoveBackground`` instance is created for every call. The output
    is uploaded under the key ``crecoAI/<taskId>_rmbg.png``.

    Args:
        task: The incoming task; its image URL and task id are read.

    Returns:
        A dict with ``generated_image_url`` set to the default S3 URL built
        from the upload key.
    """
    background_remover = RemoveBackground()
    cutout = background_remover.remove(task.get_imageUrl())

    s3_key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, s3_key)

    public_url = construct_default_s3_url(s3_key)
    return {"generated_image_url": public_url}
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image and upload the outputs.

    Avatar code names are substituted into the prompt first. With prompt
    engineering enabled, ``prompt_modifier.modify`` produces the prompt list
    (the modifier is configured with ``num_of_sequences=num_return_sequences``);
    otherwise the single prompt is repeated ``num_return_sequences`` times.

    Args:
        task: The incoming task; prompt, negative prompt, image/mask URLs,
            size, seed and task id are read from it.

    Returns:
        A dict with ``modified_prompts`` (the prompt list actually used) and
        ``generated_image_urls`` (whatever ``upload_images`` returns).
    """
    prompt = avatar.add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        # NOTE(review): assumes modify() returns a list of prompts — confirm.
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences
    print({"prompts": prompt})

    # Size the negative-prompt batch from the prompt list actually used, so the
    # two batches always line up even if the modifier returns a different
    # number of variants than requested.
    negative_prompt = [task.get_negative_prompt()] * len(prompt)

    images = inpainter.process(
        prompt=prompt,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=negative_prompt,
    )
    generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())

    clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Erase the masked object from the task's image and upload the first result.

    The first image returned by ``object_removal.process`` is uploaded under
    ``crecoAI/<taskId>_object_remove.png``. The returned URL is built with
    ``construct_default_s3_url`` from that key, the same convention used by
    ``remove_bg`` and ``upscale_image``, rather than relying on the return
    value of ``upload_image``.

    Args:
        task: The incoming task; image/mask URLs, seed, size and task id are
            read from it.

    Returns:
        A dict with ``generated_image_urls`` set to the default S3 URL of the
        uploaded image. The key stays plural for backward compatibility with
        existing consumers, even though it holds a single URL (not a list).
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    images = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    upload_image(images[0], output_key)

    clear_cuda()

    return {"generated_image_urls": construct_default_s3_url(output_key)}
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's input image and upload the result.

    Tasks whose model type is ``ModelType.ANIME`` use
    ``upscaler.upscale_anime``; all others use ``upscaler.upscale``. The
    upscaler's output is wrapped in ``BytesIO`` (so it is expected to be raw
    bytes) and uploaded under ``crecoAI/<taskId>_upscale.png``.

    Args:
        task: The incoming task; model type, image URL, resize dimension and
            task id are read from it.

    Returns:
        A dict with ``generated_image_url`` set to the default S3 URL built
        from the upload key.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    if task.get_modelType() == ModelType.ANIME:
        print("Using Anime model")
        upscale_fn = upscaler.upscale_anime
    else:
        print("Using Real model")
        upscale_fn = upscaler.upscale

    out_img = upscale_fn(
        image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension()
    )
    upload_image(BytesIO(out_img), output_key)

    # Release cached GPU memory after model inference, matching the other
    # model-backed handlers (inpaint, remove_object).
    clear_cuda()

    return {"generated_image_url": construct_default_s3_url(output_key)}
def model_fn(model_dir):
    """Load every model/pipeline used by the handlers; run once before serving.

    The ``model_fn``/``predict_fn`` pair looks like the SageMaker PyTorch
    inference-handler convention (NOTE(review): confirm the hosting runtime).

    Args:
        model_dir: Directory of model artifacts. Here only
            ``object_removal.load`` receives it; the other components load
            without it.

    Returns:
        None. ``predict_fn`` does not use its ``pipe`` argument, so nothing
        needs to be handed back.
    """
    print("Logs: model loaded .... starts")
    # Presumably anchors relative resource paths on this file's location —
    # confirm against set_root_dir.
    set_root_dir(__file__)
    FailureHandler.register()
    avatar.load_local()
    prompt_modifier.load()
    safety_checker.load()
    object_removal.load(model_dir)
    upscaler.load()
    inpainter.load()
    # Must stay after both safety_checker.load() and inpainter.load(); presumably
    # it attaches the safety checker to the loaded inpainting pipeline.
    safety_checker.apply(inpainter)
    print("Logs: model loaded ....")
    return
@FailureHandler.clear
def predict_fn(data, pipe):
    """Handle one inference request by dispatching on the task type.

    Builds a ``Task`` from ``data`` and registers it with ``FailureHandler``.
    It then applies the per-task configuration once, fetches avatars for the
    task's model, and calls the matching handler (``remove_bg``, ``inpaint``,
    ``upscale_image`` or ``remove_object``).

    Args:
        data: Raw request payload, passed to ``Task``.
        pipe: Value returned by ``model_fn``; unused here.

    Returns:
        The handler's result dict. Returns None if avatar fetching, dispatch
        or the handler raised; the error is printed and reported through
        ``slack.error_alert``.
    """
    task = Task(data)

    print("task is ", data)

    FailureHandler.handle(task)

    # Set set_environment. Applied exactly once, outside the try block, so a
    # configuration failure propagates as before (the former duplicate call
    # inside the try was redundant).
    set_configs_from_task(task)

    try:
        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()
        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        else:
            raise Exception(f"Invalid task type: {task_type}")
    except Exception as e:
        # Top-level request boundary: report the error and return None instead
        # of crashing the worker.
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
|