jayparmr's picture
Upload folder using huggingface_hub
22df957 verified
raw
history blame
6.17 kB
from enum import Enum
from functools import cached_property, lru_cache
from typing import Optional, Union

import numpy as np
class TaskType(Enum):
    """Kinds of work a Task can request.

    Member values are the strings matched against a task payload's
    ``"task_type"`` field (``Task.get_type`` calls ``TaskType(...)`` on that
    field). Member names and wire values are not always identical:
    ``TEXT_TO_IMAGE`` is sent as ``"GENERATE_AI_IMAGE"``.

    There is no ``_missing_`` override, so looking up an unknown value
    raises ``ValueError``.
    """

    # Wire value differs from the member name (see class docstring).
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"
    OBJECT_REMOVAL = "OBJECT_REMOVAL"
    SCRIBBLE = "SCRIBBLE"
    LINEARART = "LINEARART"
    REPLACE_BG = "REPLACE_BG"
    # NOTE(review): the RT_ prefix presumably marks real-time drawing modes — confirm.
    RT_DRAW_SEG = "RT_DRAW_SEG"
    RT_DRAW_IMG = "RT_DRAW_IMG"
    PRELOAD_MODEL = "PRELOAD_MODEL"
    # Parameters for this type travel in the payload's "action_data" dict
    # (see Task.get_action_data).
    CUSTOM_ACTION = "CUSTOM_ACTION"
    SYSTEM_CMD = "SYSTEM_CMD"
    OUTPAINT = "OUTPAINT"
class ModelType(Enum):
    """Model family selected by a task's numeric ``modelId``.

    Lookup never fails: any value that is not one of the ids below
    (including non-integers) resolves to ``REAL``.
    """

    REAL = 10000
    ANIME = 10001
    COMIC = 10002

    @classmethod
    def _missing_(cls, value):
        # Enum invokes this hook only after the regular value lookup fails;
        # the rejected value is irrelevant because every miss gets the same
        # fallback member.
        fallback_member = cls.REAL
        return fallback_member
class Task:
    """Accessor over a single task request payload.

    ``data`` is a plain dict (presumably the decoded job message — confirm
    against the producer). It is stored by reference, and the constructor
    normalises two fields in place, so the caller's dict changes too:

    * ``"seed"``: a missing, ``None`` or ``-1`` seed is replaced with a random
      integer in ``[0, 2**31 - 1)``.
    * ``"prompt"``: ``None`` becomes ``""``; a prompt longer than
      ``_MAX_PROMPT_LENGTH`` characters is truncated to that length and
      suffixed with ``", "``.

    Getters return the stored value or a per-field default. Only getters
    that explicitly wrap the value in ``int(...)``/``float(...)`` coerce it;
    all others return the value exactly as stored.
    """

    # Prompts longer than this many characters are truncated in __init__.
    _MAX_PROMPT_LENGTH = 200

    def __init__(self, data):
        self.__data = data

        # get_seed() maps a missing key to -1, so this covers a missing key,
        # an explicit None and the -1 sentinel. ``is None`` (not ``== None``)
        # avoids surprises from values that overload ``==``.
        if data.get("seed", -1) is None or self.get_seed() == -1:
            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int32).max)

        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        elif len(prompt) > self._MAX_PROMPT_LENGTH:
            # NOTE(review): the trailing ", " is long-standing behaviour,
            # presumably so later prompt additions join cleanly — confirm.
            self.__data["prompt"] = prompt[: self._MAX_PROMPT_LENGTH] + ", "

    # ---- identifiers and inputs --------------------------------------
    def get_taskId(self) -> Optional[str]:
        return self.__data.get("task_id")

    def get_sourceId(self) -> Optional[str]:
        return self.__data.get("source_id")

    def get_imageUrl(self) -> Optional[str]:
        return self.__data.get("imageUrl", None)

    def get_auxilary_imageUrl(self) -> Optional[str]:
        return self.__data.get("aux_imageUrl", None)

    def get_maskImageUrl(self) -> Optional[str]:
        return self.__data.get("maskImageUrl")

    # ---- prompts -----------------------------------------------------
    def get_prompt(self) -> str:
        """Return the (already normalised) prompt, or "" when absent."""
        return self.__data.get("prompt", "")

    def get_prompt_left(self) -> str:
        return self.__data.get("prompt_left", "")

    def get_prompt_right(self) -> str:
        return self.__data.get("prompt_right", "")

    def get_negative_prompt(self) -> str:
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        """Return the payload's ``"auto_mode"`` flag (default True)."""
        return self.__data.get("auto_mode", True)

    # ---- user --------------------------------------------------------
    def get_userId(self) -> str:
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        return self.__data.get("email", "")

    def get_access_token(self) -> str:
        return self.__data.get("access_token", "")

    def can_access_nsfw(self) -> bool:
        return self.__data.get("can_access_nsfw", False)

    # ---- generation parameters ---------------------------------------
    def get_style(self) -> Optional[str]:
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> "ModelType":
        """Return the ModelType for ``"modelId"``; unknown ids map to ModelType.REAL."""
        model_id = self.get_model_id()
        return ModelType(model_id)

    def get_model_id(self) -> int:
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        """Return the seed as an int; -1 means "unset" (replaced in __init__)."""
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        return int(self.__data.get("steps", 30))

    def get_type(self) -> Optional["TaskType"]:
        """Return the TaskType for ``"task_type"``, or None if missing/unrecognised."""
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_pose_coordinates(self) -> Optional[dict]:
        return self.__data.get("pose_coordinates", None)

    def get_pose_estimation(self) -> bool:
        return self.__data.get("pose_estimation", True)

    def get_image_scale(self) -> float:
        return self.__data.get("image_scale", 1.0)

    def get_queue_name(self) -> str:
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        return self.__data.get("resize_dimension", 1024)

    def get_face_enhance(self) -> bool:
        """Return the payload's ``"up_face_enhance"`` flag (default False)."""
        return self.__data.get("up_face_enhance", False)

    def rbg_controlnet_conditioning_scale(self) -> float:
        return self.__data.get("rbg_conditioning_scale", 0.5)

    def rbg_extend_object(self) -> bool:
        return self.__data.get("rbg_extend_object", False)

    def get_nsfw_threshold(self) -> float:
        return self.__data.get("nsfw_threshold", 0.03)

    def get_num_return_sequences(self) -> int:
        return self.__data.get("num_return_sequences", 4)

    def get_high_res_fix(self) -> bool:
        return self.__data.get("high_res_fix", False)

    def get_base_dimension(self):
        return self.__data.get("base_dimension", None)

    def get_action_data(self) -> dict:
        """Return the action data for CUSTOM_ACTION tasks.

        The action's identifier is stored under the ``'name'`` key. A fresh
        empty dict is returned when the payload has no ``"action_data"``.
        """
        return self.__data.get("action_data", {})

    def get_raw(self) -> dict:
        """Return a shallow copy of the underlying payload."""
        return self.__data.copy()

    # ---- prefixed keyword groups -------------------------------------
    # Each *_kwargs method returns the payload entries whose key starts with
    # its prefix, with the prefix stripped (e.g. "t2i_steps" -> "steps").
    def t2i_kwargs(self) -> dict:
        return dict(self.__get_kwargs("t2i_"))

    def i2i_kwargs(self) -> dict:
        return dict(self.__get_kwargs("i2i_"))

    def ip_kwargs(self) -> dict:
        return dict(self.__get_kwargs("ip_"))

    def cnc_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnc_"))

    def cnp_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnp_"))

    def cns_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cns_"))

    def cnl_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnl_"))

    def cnt_kwargs(self) -> dict:
        return dict(self.__get_kwargs("cnt_"))

    def high_res_kwargs(self) -> dict:
        return dict(self.__get_kwargs("hrf_"))

    def __get_kwargs(self, prefix: str):
        """Yield ``(key_without_prefix, value)`` for every key starting with ``prefix``."""
        for key, value in self.__data.items():
            if key.startswith(prefix):
                yield key[len(prefix) :], value

    # ---- prompt helpers ----------------------------------------------
    class _PromptMethods:
        """Helpers for the ``<blip:[merge]>`` placeholder in a task's prompt.

        The prompt is re-read from the task on every call, so results always
        reflect the task's current prompt. The task is never modified.
        """

        _BLIP_MERGE_PLACEHOLDER = "<blip:[merge]>"

        def __init__(self, task: "Task"):
            self.__task = task

        def has_placeholder_blip_merge(self) -> bool:
            """Return True if the prompt contains the merge placeholder."""
            return self._BLIP_MERGE_PLACEHOLDER in self.__task.get_prompt()

        def merge_blip(self, text: str) -> str:
            """Return the prompt with every merge placeholder replaced by ``text``."""
            prompt = self.__task.get_prompt()
            return prompt.replace(self._BLIP_MERGE_PLACEHOLDER, text)

    # cached_property stores the helper on the instance itself: it is built
    # once per Task and dies with it. An lru_cache on the method would key a
    # class-level cache on ``self``, keeping instances alive (ruff B019) and,
    # with maxsize=1, evicting whenever another task is accessed.
    @cached_property
    def PROMPT(self) -> "Task._PromptMethods":
        """Prompt helper bound to this task (same object on every access)."""
        return self._PromptMethods(self)