jayparmr's picture
Upload folder using huggingface_hub
86248f3
raw
history blame
3.76 kB
from enum import Enum
from typing import Union
import numpy as np
class TaskType(Enum):
    """Kinds of work a task payload can request.

    Each member's value is the string expected in a payload's ``task_type``
    field.  Note that ``TEXT_TO_IMAGE`` is spelled differently on the wire
    (``"GENERATE_AI_IMAGE"``) than its member name; every other member's
    value equals its name.
    """

    # Value intentionally differs from the member name.
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"
    OBJECT_REMOVAL = "OBJECT_REMOVAL"
    SCRIBBLE = "SCRIBBLE"
    LINEARART = "LINEARART"
class ModelType(Enum):
    """Supported model types, keyed by their numeric identifier.

    The integer values are the ids expected in a payload's ``modelId``
    field; they are consecutive, starting at 10000.
    """

    REAL = 10000
    ANIME = 10001
    COMIC = 10002
class Task:
    """Typed accessors over a raw task payload dictionary.

    ``__init__`` normalises the ``seed`` and ``prompt`` entries once; every
    other field is read on demand by a ``get_*`` / ``is_*`` / ``can_*``
    method that falls back to a default when the key is absent.
    """

    # Exclusive upper bound for randomly generated seeds.
    _MAX_SEED = np.iinfo(np.int64).max
    # Prompts are truncated to at most this many characters.
    _MAX_PROMPT_LENGTH = 200

    def __init__(self, data):
        """Wrap *data* and normalise its ``seed`` and ``prompt`` entries.

        NOTE: *data* is stored by reference and updated in place, so the
        caller's dict also sees the normalised ``seed`` and ``prompt``.

        Args:
            data: Task payload mapping (must support ``get`` and item
                assignment).

        Raises:
            ValueError / TypeError: if a present, non-null ``seed`` cannot
                be converted with ``int()``.
        """
        self.__data = data

        # A missing, null, or -1 seed means "pick one at random".
        if data.get("seed", -1) is None or self.get_seed() == -1:
            # dtype=np.int64 keeps the int64 upper bound valid on platforms
            # whose default integer is 32-bit (e.g. Windows with NumPy < 2).
            # int() stores a plain Python int, so get_raw() stays
            # JSON-serialisable.
            self.__data["seed"] = int(
                np.random.randint(0, self._MAX_SEED, dtype=np.int64)
            )

        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        else:
            self.__data["prompt"] = prompt[: self._MAX_PROMPT_LENGTH]

    def get_taskId(self) -> Union[str, None]:
        """Return ``task_id``, or ``None`` when absent."""
        return self.__data.get("task_id")

    def get_sourceId(self) -> Union[str, None]:
        """Return ``source_id``, or ``None`` when absent."""
        return self.__data.get("source_id")

    def get_imageUrl(self) -> Union[str, None]:
        """Return ``imageUrl``, or ``None`` when absent."""
        return self.__data.get("imageUrl", None)

    def get_prompt(self) -> str:
        """Return the (already truncated) ``prompt``; ``""`` when absent."""
        return self.__data.get("prompt", "")

    def get_prompt_left(self) -> str:
        """Return ``prompt_left``; ``""`` when absent."""
        return self.__data.get("prompt_left", "")

    def get_prompt_right(self) -> str:
        """Return ``prompt_right``; ``""`` when absent."""
        return self.__data.get("prompt_right", "")

    def get_userId(self) -> str:
        """Return ``userId``; ``""`` when absent."""
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        """Return ``email``; ``""`` when absent."""
        return self.__data.get("email", "")

    def get_style(self) -> Union[str, None]:
        """Return ``style``, or ``None`` when absent."""
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        """Return ``iteration`` converted to float (default 3.0)."""
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> ModelType:
        """Return the :class:`ModelType` matching :meth:`get_model_id`.

        Raises:
            ValueError: if the id is not a known ``ModelType`` value.
        """
        model_id = self.get_model_id()
        return ModelType(model_id)

    def get_model_id(self) -> int:
        """Return ``modelId`` converted to int (default 10000)."""
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        """Return ``width`` converted to int (default 512)."""
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        """Return ``height`` converted to int (default 512)."""
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        """Return ``seed`` converted to int (-1 when absent)."""
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        """Return ``steps`` converted to int (default 75)."""
        return int(self.__data.get("steps", 75))

    def get_type(self) -> Union[TaskType, None]:
        """Return the :class:`TaskType` for ``task_type``.

        Returns ``None`` when the key is absent or holds an unknown value.
        """
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> Union[str, None]:
        """Return ``maskImageUrl``, or ``None`` when absent."""
        return self.__data.get("maskImageUrl")

    def get_negative_prompt(self) -> str:
        """Return ``negative_prompt``; ``""`` when absent."""
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        """Return ``auto_mode`` as stored (default True)."""
        return self.__data.get("auto_mode", True)

    def get_queue_name(self) -> str:
        """Return ``queue_name``; ``""`` when absent."""
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        """Return ``resize_dimension`` as stored (default 1024)."""
        return self.__data.get("resize_dimension", 1024)

    def get_ti_guidance_scale(self) -> float:
        """Return ``ti_guidance_scale`` as stored (default 7.5)."""
        return self.__data.get("ti_guidance_scale", 7.5)

    def get_i2i_guidance_scale(self) -> float:
        """Return ``i2i_guidance_scale`` as stored (default 7.5)."""
        return self.__data.get("i2i_guidance_scale", 7.5)

    def get_i2i_strength(self) -> float:
        """Return ``i2i_strength`` as stored (default 0.75)."""
        return self.__data.get("i2i_strength", 0.75)

    def get_cy_guidance_scale(self) -> float:
        """Return ``cy_guidance_scale`` as stored (default 9)."""
        return self.__data.get("cy_guidance_scale", 9)

    def get_po_guidance_scale(self) -> float:
        """Return ``po_guidance_scale`` as stored (default 7.5)."""
        return self.__data.get("po_guidance_scale", 7.5)

    def get_nsfw_threshold(self) -> float:
        """Return ``nsfw_threshold`` as stored (default 0.03)."""
        return self.__data.get("nsfw_threshold", 0.03)

    def can_access_nsfw(self) -> bool:
        """Return ``can_access_nsfw`` as stored (default False)."""
        return self.__data.get("can_access_nsfw", False)

    def get_access_token(self) -> str:
        """Return ``access_token``; ``""`` when absent."""
        return self.__data.get("access_token", "")

    def get_raw(self) -> dict:
        """Return a shallow copy of the (normalised) payload dict."""
        return self.__data.copy()