|
from enum import Enum |
|
from functools import lru_cache |
|
from typing import Optional, Union |
|
|
|
import numpy as np |
|
|
|
|
|
class TaskType(Enum):
    """Closed set of task kinds, keyed by the string found in a task payload.

    Each member's value is the literal string matched against the payload's
    ``task_type`` field (see ``Task.get_type``). Every value equals its member
    name except ``TEXT_TO_IMAGE``, whose wire value is ``"GENERATE_AI_IMAGE"``.
    """

    # Note: the wire value differs from the member name here.
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"
    OBJECT_REMOVAL = "OBJECT_REMOVAL"
    SCRIBBLE = "SCRIBBLE"
    LINEARART = "LINEARART"
    REPLACE_BG = "REPLACE_BG"
    RT_DRAW_SEG = "RT_DRAW_SEG"
    RT_DRAW_IMG = "RT_DRAW_IMG"
    PRELOAD_MODEL = "PRELOAD_MODEL"
    CUSTOM_ACTION = "CUSTOM_ACTION"
    SYSTEM_CMD = "SYSTEM_CMD"
    OUTPAINT = "OUTPAINT"
|
|
|
|
|
class ModelType(Enum):
    """Model family selected by a numeric model id.

    Looking up a value that matches no member does not raise; it resolves to
    ``ModelType.REAL`` through the ``_missing_`` hook below.
    """

    REAL = 10000
    ANIME = 10001
    COMIC = 10002

    @classmethod
    def _missing_(cls, value):
        """Return ``REAL`` for any value that is not a declared member."""
        # Enum calls this hook only after the normal by-value lookup failed,
        # so every unrecognised id funnels into the REAL fallback.
        return cls.REAL
|
|
|
|
|
class Task:
    """Typed accessor wrapper around a raw task payload dictionary.

    The payload dict is stored by reference (not copied) and is normalised in
    place on construction:

    * a missing, ``None`` or ``-1`` ``"seed"`` is replaced by a random
      non-negative 32-bit integer;
    * a ``None`` ``"prompt"`` becomes ``""``;
    * a ``"prompt"`` longer than 200 characters is cut to its first 200
      characters followed by ``", "``.

    All ``get_*`` methods read from that dict and fall back to the default
    shown in their docstring when the key is absent.
    """

    def __init__(self, data):
        """Wrap *data* (a dict-like payload) and normalise seed and prompt.

        Args:
            data: mapping holding the task fields; it is mutated in place.
        """
        self.__data = data
        # Per-instance slot for the lazily built PROMPT helper. A class-wide
        # ``lru_cache`` on the property would be shared by every Task and keep
        # the most recently used instance alive.
        self.__prompt_methods = None

        # ``is None`` is checked first so ``get_seed`` never calls int(None).
        if data.get("seed", -1) is None or self.get_seed() == -1:
            # Upper bound is exclusive, so the seed fits in a signed int32.
            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int32).max)

        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        elif len(prompt) > 200:
            # Cap the prompt length; the trailing ", " presumably lets later
            # code append more comma-separated terms -- confirm with callers.
            self.__data["prompt"] = prompt[:200] + ", "

    def get_taskId(self) -> Optional[str]:
        """Return the ``"task_id"`` field, or ``None`` when absent."""
        return self.__data.get("task_id")

    def get_sourceId(self) -> Optional[str]:
        """Return the ``"source_id"`` field, or ``None`` when absent."""
        return self.__data.get("source_id")

    def get_imageUrl(self) -> Optional[str]:
        """Return the ``"imageUrl"`` field, or ``None`` when absent."""
        return self.__data.get("imageUrl", None)

    def get_auxilary_imageUrl(self) -> Optional[str]:
        """Return the ``"aux_imageUrl"`` field, or ``None`` when absent."""
        return self.__data.get("aux_imageUrl", None)

    def get_prompt(self) -> str:
        """Return the ``"prompt"`` field (default ``""``)."""
        return self.__data.get("prompt", "")

    def get_prompt_left(self) -> str:
        """Return the ``"prompt_left"`` field (default ``""``)."""
        return self.__data.get("prompt_left", "")

    def get_prompt_right(self) -> str:
        """Return the ``"prompt_right"`` field (default ``""``)."""
        return self.__data.get("prompt_right", "")

    def get_userId(self) -> str:
        """Return the ``"userId"`` field (default ``""``)."""
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        """Return the ``"email"`` field (default ``""``)."""
        return self.__data.get("email", "")

    def get_style(self) -> Optional[str]:
        """Return the ``"style"`` field, or ``None`` when absent."""
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        """Return the ``"iteration"`` field converted to float (default 3.0)."""
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> "ModelType":
        """Return the ``ModelType`` for this task's model id.

        Ids that match no member resolve to ``ModelType.REAL`` via the enum's
        ``_missing_`` hook.
        """
        return ModelType(self.get_model_id())

    def get_model_id(self) -> int:
        """Return the ``"modelId"`` field converted to int (default 10000)."""
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        """Return the ``"width"`` field converted to int (default 512)."""
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        """Return the ``"height"`` field converted to int (default 512)."""
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        """Return the ``"seed"`` field converted to int (default -1)."""
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        """Return the ``"steps"`` field converted to int (default 30)."""
        return int(self.__data.get("steps", 30))

    def get_type(self) -> "Optional[TaskType]":
        """Return the ``TaskType`` named by ``"task_type"``, or ``None``.

        ``None`` is returned when the field is absent or is not a known
        ``TaskType`` value.
        """
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> Optional[str]:
        """Return the ``"maskImageUrl"`` field, or ``None`` when absent."""
        return self.__data.get("maskImageUrl")

    def get_pose_coordinates(self) -> Optional[dict]:
        """Return the ``"pose_coordinates"`` field, or ``None`` when absent."""
        return self.__data.get("pose_coordinates", None)

    def get_pose_estimation(self) -> bool:
        """Return the ``"pose_estimation"`` field (default ``True``)."""
        return self.__data.get("pose_estimation", True)

    def get_negative_prompt(self) -> str:
        """Return the ``"negative_prompt"`` field (default ``""``)."""
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        """Return the ``"auto_mode"`` field (default ``True``)."""
        return self.__data.get("auto_mode", True)

    def get_image_scale(self) -> float:
        """Return the ``"image_scale"`` field (default 1.0)."""
        return self.__data.get("image_scale", 1.0)

    def get_queue_name(self) -> str:
        """Return the ``"queue_name"`` field (default ``""``)."""
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        """Return the ``"resize_dimension"`` field (default 1024)."""
        return self.__data.get("resize_dimension", 1024)

    def get_face_enhance(self) -> bool:
        """Return the ``"up_face_enhance"`` field (default ``False``)."""
        return self.__data.get("up_face_enhance", False)

    def rbg_controlnet_conditioning_scale(self) -> float:
        """Return the ``"rbg_conditioning_scale"`` field (default 0.5)."""
        return self.__data.get("rbg_conditioning_scale", 0.5)

    def rbg_extend_object(self) -> bool:
        """Return the ``"rbg_extend_object"`` field (default ``False``)."""
        return self.__data.get("rbg_extend_object", False)

    def get_nsfw_threshold(self) -> float:
        """Return the ``"nsfw_threshold"`` field (default 0.03)."""
        return self.__data.get("nsfw_threshold", 0.03)

    def get_num_return_sequences(self) -> int:
        """Return the ``"num_return_sequences"`` field (default 4)."""
        return self.__data.get("num_return_sequences", 4)

    def can_access_nsfw(self) -> bool:
        """Return the ``"can_access_nsfw"`` field (default ``False``)."""
        return self.__data.get("can_access_nsfw", False)

    def get_access_token(self) -> str:
        """Return the ``"access_token"`` field (default ``""``)."""
        return self.__data.get("access_token", "")

    def get_high_res_fix(self) -> bool:
        """Return the ``"high_res_fix"`` field (default ``False``)."""
        return self.__data.get("high_res_fix", False)

    def get_base_dimension(self):
        """Return the ``"base_dimension"`` field, or ``None`` when absent."""
        return self.__data.get("base_dimension", None)

    def get_action_data(self) -> dict:
        """Return the ``"action_data"`` field (default ``{}``).

        If task_type is CUSTOM_ACTION, then this will return the action data
        with 'name' as key.
        """
        return self.__data.get("action_data", {})

    def get_raw(self) -> dict:
        """Return a shallow copy of the underlying payload dict."""
        return self.__data.copy()

    def t2i_kwargs(self) -> dict:
        """Return payload entries prefixed ``"t2i_"``, with the prefix removed."""
        return dict(self.__get_kwargs("t2i_"))

    def i2i_kwargs(self) -> dict:
        """Return payload entries prefixed ``"i2i_"``, with the prefix removed."""
        return dict(self.__get_kwargs("i2i_"))

    def ip_kwargs(self) -> dict:
        """Return payload entries prefixed ``"ip_"``, with the prefix removed."""
        return dict(self.__get_kwargs("ip_"))

    def cnc_kwargs(self) -> dict:
        """Return payload entries prefixed ``"cnc_"``, with the prefix removed."""
        return dict(self.__get_kwargs("cnc_"))

    def cnp_kwargs(self) -> dict:
        """Return payload entries prefixed ``"cnp_"``, with the prefix removed."""
        return dict(self.__get_kwargs("cnp_"))

    def cns_kwargs(self) -> dict:
        """Return payload entries prefixed ``"cns_"``, with the prefix removed."""
        return dict(self.__get_kwargs("cns_"))

    def cnl_kwargs(self) -> dict:
        """Return payload entries prefixed ``"cnl_"``, with the prefix removed."""
        return dict(self.__get_kwargs("cnl_"))

    def cnt_kwargs(self) -> dict:
        """Return payload entries prefixed ``"cnt_"``, with the prefix removed."""
        return dict(self.__get_kwargs("cnt_"))

    def high_res_kwargs(self) -> dict:
        """Return payload entries prefixed ``"hrf_"``, with the prefix removed."""
        return dict(self.__get_kwargs("hrf_"))

    def __get_kwargs(self, prefix: str):
        """Yield ``(key_without_prefix, value)`` for keys starting with *prefix*."""
        cut = len(prefix)
        for k, v in self.__data.items():
            if k.startswith(prefix):
                yield k[cut:], v

    @property
    def PROMPT(self):
        """Return this task's prompt helper, built once per instance.

        The helper exposes ``has_placeholder_blip_merge()`` and
        ``merge_blip(text)``, both of which read the task's current prompt.
        Repeated access on the same task returns the same helper object.
        """
        helper = self.__prompt_methods
        if helper is not None:
            return helper

        class PromptMethods:
            """Prompt utilities bound to a single task."""

            def __init__(self, task: "Task"):
                self.__task = task

            def has_placeholder_blip_merge(self) -> bool:
                """Return whether the prompt contains ``<blip:[merge]>``."""
                return "<blip:[merge]>" in self.__task.get_prompt()

            def merge_blip(self, text: str) -> str:
                """Return the prompt with every ``<blip:[merge]>`` replaced by *text*."""
                return self.__task.get_prompt().replace("<blip:[merge]>", text)

        helper = PromptMethods(self)
        self.__prompt_methods = helper
        return helper
|
|