from enum import Enum
from functools import cached_property, lru_cache
from typing import Optional, Union

import numpy as np


class TaskType(Enum):
    """Kinds of task, matched against the ``task_type`` value of a task payload."""

    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"
    OBJECT_REMOVAL = "OBJECT_REMOVAL"
    SCRIBBLE = "SCRIBBLE"
    LINEARART = "LINEARART"
    REPLACE_BG = "REPLACE_BG"
    RT_DRAW_SEG = "RT_DRAW_SEG"
    RT_DRAW_IMG = "RT_DRAW_IMG"
    PRELOAD_MODEL = "PRELOAD_MODEL"
    CUSTOM_ACTION = "CUSTOM_ACTION"
    SYSTEM_CMD = "SYSTEM_CMD"
    OUTPAINT = "OUTPAINT"


class ModelType(Enum):
    """Model families selected by the numeric ``modelId`` of a task payload."""

    REAL = 10000
    ANIME = 10001
    COMIC = 10002

    @classmethod
    def _missing_(cls, value):
        # Any unrecognised model id falls back to REAL instead of raising ValueError.
        return cls.REAL


class Task:
    """Accessor facade over a raw task payload dict.

    The payload is stored by reference (not copied), so the normalisation
    performed in ``__init__`` (seed and prompt) is visible through the
    caller's dict as well.  Use :meth:`get_raw` for an independent copy.
    """

    def __init__(self, data):
        """Wrap ``data`` and normalise its ``seed`` and ``prompt`` entries in place.

        - A missing, ``None`` or ``-1`` seed is replaced with a random value
          in ``[0, int32 max)``.
        - A ``None`` prompt becomes ``""``; a prompt longer than 200 characters
          is cut to its first 200 characters followed by ``", "``.
        """
        self.__data = data
        # Short-circuit keeps int() from being applied to a None seed.
        if data.get("seed", -1) is None or self.get_seed() == -1:
            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int32).max)
        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        elif len(prompt) > 200:
            # The trailing ", " is preserved from the original behaviour;
            # presumably so more text can be appended downstream — TODO confirm.
            self.__data["prompt"] = prompt[:200] + ", "

    # ----- identifiers / user --------------------------------------------

    def get_taskId(self) -> Optional[str]:
        return self.__data.get("task_id")

    def get_sourceId(self) -> Optional[str]:
        return self.__data.get("source_id")

    def get_userId(self) -> str:
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        return self.__data.get("email", "")

    def get_access_token(self) -> str:
        return self.__data.get("access_token", "")

    def get_queue_name(self) -> str:
        return self.__data.get("queue_name", "")

    # ----- images ----------------------------------------------------------

    def get_imageUrl(self) -> Optional[str]:
        return self.__data.get("imageUrl", None)

    def get_auxilary_imageUrl(self) -> Optional[str]:
        return self.__data.get("aux_imageUrl", None)

    def get_maskImageUrl(self) -> Optional[str]:
        return self.__data.get("maskImageUrl")

    # ----- prompts ---------------------------------------------------------

    def get_prompt(self) -> str:
        return self.__data.get("prompt", "")

    def get_prompt_left(self) -> str:
        return self.__data.get("prompt_left", "")

    def get_prompt_right(self) -> str:
        return self.__data.get("prompt_right", "")

    def get_negative_prompt(self) -> str:
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        return self.__data.get("auto_mode", True)

    # ----- generation parameters -------------------------------------------

    def get_style(self) -> Optional[str]:
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> ModelType:
        """Return the ModelType for ``modelId``; unknown ids map to REAL."""
        return ModelType(self.get_model_id())

    def get_model_id(self) -> int:
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        return int(self.__data.get("steps", 30))

    def get_type(self) -> Union[TaskType, None]:
        """Return the TaskType for ``task_type``, or None if it is missing/unknown."""
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_pose_coordinates(self) -> Optional[dict]:
        return self.__data.get("pose_coordinates", None)

    def get_pose_estimation(self) -> bool:
        return self.__data.get("pose_estimation", True)

    def get_image_scale(self) -> float:
        return self.__data.get("image_scale", 1.0)

    def get_resize_dimension(self) -> int:
        return self.__data.get("resize_dimension", 1024)

    def get_face_enhance(self) -> bool:
        return self.__data.get("up_face_enhance", False)

    def rbg_controlnet_conditioning_scale(self) -> float:
        return self.__data.get("rbg_conditioning_scale", 0.5)

    def rbg_extend_object(self) -> bool:
        return self.__data.get("rbg_extend_object", False)

    def get_nsfw_threshold(self) -> float:
        return self.__data.get("nsfw_threshold", 0.03)

    def get_num_return_sequences(self) -> int:
        return self.__data.get("num_return_sequences", 4)

    def can_access_nsfw(self) -> bool:
        return self.__data.get("can_access_nsfw", False)

    def get_high_res_fix(self) -> bool:
        return self.__data.get("high_res_fix", False)

    def get_base_dimension(self):
        return self.__data.get("base_dimension", None)

    def get_action_data(self) -> dict:
        """Return the ``action_data`` dict (used when task_type is CUSTOM_ACTION).

        Per the original note, the action name is stored under the 'name' key.
        """
        return self.__data.get("action_data", {})

    def get_raw(self) -> dict:
        """Return a shallow copy of the underlying payload."""
        return self.__data.copy()

    # ----- prefixed keyword groups -----------------------------------------
    # Each returns the payload entries whose key starts with the given prefix,
    # with the prefix stripped (e.g. "t2i_guidance" -> "guidance").

    def t2i_kwargs(self) -> dict:
        return dict(self.__get_kwargs("t2i_"))

    def i2i_kwargs(self) -> dict:
        return dict(self.__get_kwargs("i2i_"))

    def ip_kwargs(self) -> dict:
        return dict(self.__get_kwargs("ip_"))

    def cnc_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnc_"))

    def cnp_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnp_"))

    def cns_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cns_"))

    def cnl_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnl_"))

    def cnt_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnt_"))

    def high_res_kwargs(self) -> dict:
        return dict(self.__get_kwargs("hrf_"))

    def __get_kwargs(self, prefix: str):
        """Yield ``(key_without_prefix, value)`` for payload keys starting with ``prefix``."""
        for k, v in self.__data.items():
            if k.startswith(prefix):
                yield k[len(prefix):], v

    # cached_property stores the helper on the instance itself. The previous
    # @property + @lru_cache(1) pinned the last-used Task in a module-level
    # cache (keeping it alive) and rebuilt the helper whenever tasks alternated.
    @cached_property
    def PROMPT(self):
        """Per-task helper exposing prompt-template operations (built once per task)."""

        class PromptMethods:
            def __init__(self, task: "Task"):
                self.__task = task

            # NOTE(review): the placeholder literal below is empty, so this
            # check is always True and merge_blip inserts ``text`` between
            # every character of the prompt. The intended token appears to
            # have been lost (possibly stripped as markup) — TODO restore it.
            def has_placeholder_blip_merge(self) -> bool:
                return "" in self.__task.get_prompt()

            def merge_blip(self, text: str) -> str:
                return self.__task.get_prompt().replace("", text)

        return PromptMethods(self)