# CM2000112 / inference.py
# Uploaded by jayparmr via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 1bc457e, 15.9 kB.
import os
from typing import List, Optional
import torch
import internals.util.prompt as prompt_util
from internals.data.dataAccessor import update_db
from internals.data.task import Task, TaskType
from internals.pipelines.commons import Img2Img, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.pose_detector import PoseDetector
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.safety_checker import SafetyChecker
from internals.util.args import apply_style_args
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
from internals.util.commons import download_image, upload_image, upload_images
from internals.util.config import (
get_model_dir,
num_return_sequences,
set_configs_from_task,
set_model_dir,
set_root_dir,
)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.slack import Slack
# Process-wide PyTorch backend flags: enable cuDNN benchmark mode (autotuned
# convolution algorithms) and allow TF32 for CUDA matmuls, trading a little
# fp32 precision for speed.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not referenced anywhere in this file's visible code — confirm
# whether it is still used before removing.
auto_mode = False

# Module-level singletons, constructed once at import time and shared by every
# request this process serves. The task handlers below load, LoRA-patch and
# clean up these objects per request. Construction order is kept as-is in case
# constructors have side effects.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
pose_detector = PoseDetector()
inpainter = InPainter()
high_res = HighRes()
img2text = Image2Text()
img_classifier = ImageClassifier()
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
safety_checker = SafetyChecker()
# Also supplies the `auto_send_alert` decorator applied to the handlers below.
slack = Slack()
avatar = Avatar()
def get_patched_prompt(task: Task):
    """Patch ``task``'s prompt using the shared avatar, LoRA-style and prompt-modifier objects.

    Thin delegate to ``prompt_util.get_patched_prompt``. Callers in this module
    unpack the result as ``prompt, _``.
    """
    shared_helpers = (avatar, lora_style, prompt_modifier)
    return prompt_util.get_patched_prompt(task, *shared_helpers)
def get_patched_prompt_text2img(task: Task):
    """Build the text-to-image prompt parameters for ``task``.

    Thin delegate to ``prompt_util.get_patched_prompt_text2img`` with the shared
    avatar, LoRA-style and prompt-modifier objects. ``text2img`` reads
    ``.prompt`` from the result and merges its ``__dict__`` into its response.
    """
    shared_helpers = (avatar, lora_style, prompt_modifier)
    return prompt_util.get_patched_prompt_text2img(task, *shared_helpers)
def get_patched_prompt_tile_upscale(task: Task):
    """Build the tile-upscale prompt for ``task``.

    Thin delegate to ``prompt_util.get_patched_prompt_tile_upscale``. It passes
    the shared avatar and LoRA-style objects plus the image classifier and
    image-to-text models.
    """
    shared_helpers = (avatar, lora_style, img_classifier, img2text)
    return prompt_util.get_patched_prompt_tile_upscale(task, *shared_helpers)
def get_intermediate_dimension(task: Task):
    """Return the ``(width, height)`` to generate at before any high-res pass.

    Without the high-res fix this is simply the task's requested size. With it,
    the size comes from ``HighRes.get_intermediate_dimension`` applied to the
    requested size.
    """
    if not task.get_high_res_fix():
        return task.get_width(), task.get_height()
    return HighRes.get_intermediate_dimension(task.get_width(), task.get_height())
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images for ``task`` with the ControlNet canny pipeline.

    Patches the prompt, applies the task's LoRA style to ``controlnet.pipe2`` and
    runs ``controlnet.process_canny`` on the task's input image. It then
    optionally applies the high-res fix and uploads the results with the
    ``_canny`` suffix.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_canny()

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        images, has_nsfw = controlnet.process_canny(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=width,
            height=height,
            guidance_scale=task.get_cy_guidance_scale(),
            negative_prompt=[
                f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
            ]
            * num_return_sequences,
            **lora_patcher.kwargs(),
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                prompt=prompt,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        generated_image_urls = upload_images(images, "_canny", task.get_taskId())
    finally:
        # controlnet.pipe2 is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when generation or upload fails.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task's image with the ControlNet tile pipeline.

    Applies the task's LoRA style to ``controlnet.pipe`` and runs
    ``controlnet.process_tile_upscaler``. Only the first resulting image is
    uploaded, to ``crecoAI/<taskId>_tile_upscaler.png``.

    Returns:
        dict with ``modified_prompts``, ``generated_image_url`` (a single URL)
        and ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_tile_upscaler()

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()
    try:
        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not forwarded
        # to the pipeline here — confirm this is intended.
        images, has_nsfw = controlnet.process_tile_upscaler(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            resize_dimension=task.get_resize_dimension(),
            negative_prompt=task.get_negative_prompt(),
            guidance_scale=task.get_ti_guidance_scale(),
        )

        generated_image_url = upload_image(images[0], output_key)
    finally:
        # controlnet.pipe is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when upscaling or upload fails.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Generate images for ``task`` with the ControlNet scribble pipeline.

    Applies the task's LoRA style to ``controlnet.pipe2`` and runs
    ``controlnet.process_scribble`` on the task's input image. It then
    optionally applies the high-res fix and uploads the results with the
    ``_scribble`` suffix.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_scribble()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        # NOTE(review): unlike canny/pose, neither lora_patcher.kwargs() nor a
        # guidance_scale is forwarded here — confirm this is intended.
        images, has_nsfw = controlnet.process_scribble(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=width,
            height=height,
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                prompt=prompt,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
    finally:
        # controlnet.pipe2 is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when generation or upload fails.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Generate images for ``task`` with the ControlNet line-art pipeline.

    Applies the task's LoRA style to ``controlnet.pipe2`` and runs
    ``controlnet.process_linearart`` on the task's input image. It then
    optionally applies the high-res fix and uploads the results with the
    ``_linearart`` suffix.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_linearart()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        # NOTE(review): unlike canny/pose, neither lora_patcher.kwargs() nor a
        # guidance_scale is forwarded here — confirm this is intended.
        images, has_nsfw = controlnet.process_linearart(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=width,
            height=height,
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                prompt=prompt,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
    finally:
        # controlnet.pipe2 is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when generation or upload fails.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images for ``task`` with the ControlNet pose pipeline.

    Args:
        task: the request being served.
        s3_outkey: suffix passed to ``upload_images`` for the generated images.
        poses: optional pre-computed pose images, one per generated image
            (presumably PIL images, like the derived ones — confirm with
            callers). When ``None`` or empty, poses are derived from the task:
            - without pose estimation, the input image itself, resized to the
              task size;
            - with client pose coordinates, ``pose_detector.transform``;
            - otherwise, ``controlnet.detect_pose`` on the input image.

    The first pose image is also uploaded to ``crecoAI/<taskId>_pose.png``.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    controlnet.load_pose()

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()
    try:
        # Honour caller-supplied poses; previously this argument was always
        # overwritten and therefore ignored.
        if not poses:
            if not task.get_pose_estimation():
                pose_image = download_image(task.get_imageUrl()).resize(
                    (task.get_width(), task.get_height())
                )
            elif task.get_pose_coordinates():
                pose_image = pose_detector.transform(
                    image=task.get_imageUrl(),
                    client_coordinates=task.get_pose_coordinates(),
                    width=task.get_width(),
                    height=task.get_height(),
                )
            else:
                pose_image = controlnet.detect_pose(task.get_imageUrl())
            poses = [pose_image] * num_return_sequences

        images, has_nsfw = controlnet.process_pose(
            prompt=prompt,
            image=poses,
            seed=task.get_seed(),
            steps=task.get_steps(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            width=width,
            height=height,
            guidance_scale=task.get_po_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                prompt=prompt,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        pose_output_key = "crecoAI/{}_pose.png".format(task.get_taskId())
        upload_image(poses[0], pose_output_key)

        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
    finally:
        # controlnet.pipe2 is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when any step above fails.
        lora_patcher.cleanup()

    controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from text for ``task`` with the shared text-to-image pipeline.

    Applies the task's LoRA style to ``text2img_pipe.pipe``, seeds torch's global
    RNG with the task seed and runs ``text2img_pipe.process`` (guidance scale
    fixed at 7.5). It then optionally applies the high-res fix and uploads the
    results with an empty suffix.

    Returns:
        dict merging the prompt params' attributes (``params.__dict__``) with
        ``generated_image_urls`` and ``has_nsfw``.
    """
    params = get_patched_prompt_text2img(task)
    width, height = get_intermediate_dimension(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = text2img_pipe.process(
            params=params,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=height,
            width=width,
            negative_prompt=task.get_negative_prompt(),
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                # Fall back to empty prompts when the params carry none.
                prompt=params.prompt if params.prompt else [""] * num_return_sequences,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # text2img_pipe is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when generation or upload fails.
        lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Transform the task's input image with the shared image-to-image pipeline.

    Applies the task's LoRA style to ``img2img_pipe.pipe``, seeds torch's global
    RNG with the task seed and runs ``img2img_pipe.process`` using the task's
    strength and guidance scale. It then optionally applies the high-res fix and
    uploads the results with the ``_imgtoimg`` suffix.

    Returns:
        dict with ``modified_prompts``, ``generated_image_urls`` and ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
    try:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            width=width,
            height=height,
            strength=task.get_i2i_strength(),
            guidance_scale=task.get_i2i_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                prompt=prompt,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        # img2img_pipe is a module-level singleton shared across requests, so
        # the LoRA patch must be undone even when generation or upload fails.
        lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image.

    Runs ``inpainter.process`` with the task's image and mask URLs. It then
    optionally applies the high-res fix and uploads the results with the
    ``_inpaint`` suffix.

    Returns:
        dict with ``modified_prompts`` and ``generated_image_urls``. Unlike the
        other handlers, there is no ``has_nsfw`` key.
    """
    prompt, _ = get_patched_prompt(task)
    width, height = get_intermediate_dimension(task)

    print({"prompts": prompt})

    try:
        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=width,
            height=height,
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        if task.get_high_res_fix():
            images, _ = high_res.apply(
                prompt=prompt,
                negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
                images=images,
                width=task.get_width(),
                height=task.get_height(),
                steps=task.get_steps(),
            )

        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # This handler is not wrapped in auto_clear_cuda_and_gc, so clear CUDA
        # explicitly — and also on failure, not only after a successful upload.
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
def load_model_by_task(task: Task):
    """Ensure the pipelines needed by ``task`` are loaded.

    Text-to-image, image-to-image and inpaint tasks share one base pipeline. The
    first such task:
    - loads ``text2img_pipe`` from the model dir;
    - derives the img2img, inpaint and high-res pipelines from it;
    - applies the safety checker to the text2img and img2img pipes.
    Later tasks of these types reuse that setup and do nothing here.

    Every other task type:
    - loads its ControlNet variant, if it has one;
    - calls ``high_res.load()`` with no pipeline argument;
    - applies the safety checker to ``controlnet``.
    """
    task_type = task.get_type()

    if task_type in (
        TaskType.TEXT_TO_IMAGE,
        TaskType.IMAGE_TO_IMAGE,
        TaskType.INPAINT,
    ):
        if not text2img_pipe.is_loaded():
            text2img_pipe.load(get_model_dir())
            img2img_pipe.create(text2img_pipe)
            inpainter.create(text2img_pipe)
            high_res.load(img2img_pipe)

            safety_checker.apply(text2img_pipe)
            safety_checker.apply(img2img_pipe)
        # Return even when already loaded, so these task types never run the
        # ControlNet-only setup below (standalone high_res.load() and
        # safety_checker.apply(controlnet)).
        return

    if task_type == TaskType.TILE_UPSCALE:
        controlnet.load_tile_upscaler()
    elif task_type == TaskType.CANNY:
        controlnet.load_canny()
    elif task_type == TaskType.SCRIBBLE:
        controlnet.load_scribble()
    elif task_type == TaskType.LINEARART:
        controlnet.load_linearart()
    elif task_type == TaskType.POSE:
        controlnet.load_pose()

    high_res.load()

    safety_checker.apply(controlnet)
def model_fn(model_dir):
    """Initialise process-wide state from ``model_dir``.

    Records the model and root directories, registers the failure handler, and
    loads avatars and LoRA styles from ``model_dir``. The diffusion pipelines
    are loaded later, per task, by ``load_model_by_task`` (called from
    ``predict_fn``).

    Returns:
        None.
    """
    print("Logs: model loaded .... starts")

    set_model_dir(model_dir)
    set_root_dir(__file__)
    FailureHandler.register()

    for load_from_dir in (avatar.load_local, lora_style.load):
        load_from_dir(model_dir)

    print("Logs: model loaded ....")
@FailureHandler.clear
def predict_fn(data, pipe):
    """Serve one inference request.

    Args:
        data: raw request payload. It is wrapped in a ``Task`` and also passed
            to ``apply_style_args``.
        pipe: the value returned by ``model_fn``; not used here.

    Returns:
        The selected handler's result. Returns ``None`` for a system command,
        or when any step inside the ``try`` raises; in that case the error is
        printed, sent to Slack, and ``controlnet`` is cleaned up.

    Raises:
        Whatever ``Task(data)`` or ``FailureHandler.handle`` raise, since both
        run before the ``try``.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Per-request setup: configs, models, style args, LoRA styles, avatars.
        set_configs_from_task(task)
        load_model_by_task(task)
        apply_style_args(data)
        lora_style.fetch_styles()
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        # Checked in order with ``==``; the first matching type handles the task.
        # (A disabled "character sheet" route once sent matching text-to-image
        # prompts to pose(task, s3_outkey="", poses=...).)
        dispatch = (
            (TaskType.TEXT_TO_IMAGE, text2img),
            (TaskType.IMAGE_TO_IMAGE, img2img),
            (TaskType.CANNY, canny),
            (TaskType.POSE, pose),
            (TaskType.TILE_UPSCALE, tile_upscale),
            (TaskType.INPAINT, inpaint),
            (TaskType.SCRIBBLE, scribble),
            (TaskType.LINEARART, linearart),
        )
        for handled_type, handler in dispatch:
            if task_type == handled_type:
                return handler(task)

        if task_type == TaskType.SYSTEM_CMD:
            # SECURITY(review): this runs the request's prompt verbatim through a
            # shell. Anyone who can send a SYSTEM_CMD task gets arbitrary command
            # execution on the host. Gate or remove this unless the endpoint is
            # strictly internal.
            os.system(task.get_prompt())
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        controlnet.cleanup()
        return None