|
from enum import Enum |
|
from typing import Union |
|
|
|
import numpy as np |
|
|
|
|
|
class TaskType(Enum):
    """Kinds of task a ``Task`` payload can request.

    Member values are the exact strings carried in the payload's
    ``task_type`` field; note that ``TEXT_TO_IMAGE`` maps to the string
    ``"GENERATE_AI_IMAGE"`` rather than its own name.
    """

    # Generation from text or from an existing image.
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"

    # Control-conditioned generation modes.
    POSE = "POSE"
    CANNY = "CANNY"

    # Image editing operations.
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"

    # Resolution enhancement.
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"

    # More editing / control-conditioned modes.
    OBJECT_REMOVAL = "OBJECT_REMOVAL"
    SCRIBBLE = "SCRIBBLE"
    LINEARART = "LINEARART"
|
|
|
|
|
class ModelType(Enum):
    """Model families, keyed by the integer ``modelId`` carried in a payload."""

    REAL = 10_000
    ANIME = 10_001
    COMIC = 10_002
|
|
|
|
|
class Task:
    """Accessor wrapper around a task payload dictionary.

    The constructor normalises two fields: ``seed`` (a random seed is drawn
    when it is missing, ``None`` or ``-1``) and ``prompt`` (``None`` becomes
    ``""`` and longer prompts are truncated to ``MAX_PROMPT_LENGTH``
    characters). Every other field is read through a ``get_*`` accessor that
    falls back to a default when the key is absent.
    """

    # Maximum number of characters kept from the incoming prompt.
    MAX_PROMPT_LENGTH = 200

    def __init__(self, data):
        """Wrap *data* and normalise its ``seed`` and ``prompt`` entries.

        *data* is stored by reference (not copied), so the normalised
        ``seed`` and ``prompt`` values are also written into the caller's dict.
        """
        self.__data = data

        if data.get("seed", -1) is None or self.get_seed() == -1:
            # Pin the dtype so the int64 upper bound is valid even where
            # numpy's default integer is 32-bit, and store a plain ``int`` so
            # the payload returned by ``get_raw`` stays JSON-serialisable.
            self.__data["seed"] = int(
                np.random.randint(0, np.iinfo(np.int64).max, dtype=np.int64)
            )

        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        else:
            self.__data["prompt"] = prompt[: self.MAX_PROMPT_LENGTH]

    def get_taskId(self) -> Union[str, None]:
        """Return ``task_id``, or ``None`` when absent."""
        return self.__data.get("task_id")

    def get_sourceId(self) -> Union[str, None]:
        """Return ``source_id``, or ``None`` when absent."""
        return self.__data.get("source_id")

    def get_imageUrl(self) -> Union[str, None]:
        """Return ``imageUrl``, or ``None`` when absent."""
        return self.__data.get("imageUrl", None)

    def get_prompt(self) -> str:
        """Return the (already truncated) prompt; ``""`` when absent."""
        return self.__data.get("prompt", "")

    def get_prompt_left(self) -> str:
        """Return ``prompt_left``; ``""`` when absent."""
        return self.__data.get("prompt_left", "")

    def get_prompt_right(self) -> str:
        """Return ``prompt_right``; ``""`` when absent."""
        return self.__data.get("prompt_right", "")

    def get_userId(self) -> str:
        """Return ``userId``; ``""`` when absent."""
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        """Return ``email``; ``""`` when absent."""
        return self.__data.get("email", "")

    def get_style(self) -> Union[str, None]:
        """Return ``style``, or ``None`` when absent."""
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        """Return ``iteration`` coerced to float; ``3.0`` when absent."""
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> ModelType:
        """Return the ``ModelType`` for ``modelId``.

        Raises ``ValueError`` when the id matches no ``ModelType`` member.
        """
        model_id = self.get_model_id()
        return ModelType(model_id)

    def get_model_id(self) -> int:
        """Return ``modelId`` coerced to int; ``10000`` when absent."""
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        """Return ``width`` coerced to int; ``512`` when absent."""
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        """Return ``height`` coerced to int; ``512`` when absent."""
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        """Return ``seed`` coerced to int; ``-1`` when absent."""
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        """Return ``steps`` coerced to int; ``75`` when absent."""
        return int(self.__data.get("steps", 75))

    def get_type(self) -> Union[TaskType, None]:
        """Return the ``TaskType`` for ``task_type``, or ``None`` if unknown/absent."""
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> Union[str, None]:
        """Return ``maskImageUrl``, or ``None`` when absent."""
        return self.__data.get("maskImageUrl")

    def get_negative_prompt(self) -> str:
        """Return ``negative_prompt``; ``""`` when absent."""
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        """Return the raw ``auto_mode`` value; ``True`` when absent.

        NOTE(review): the stored value is returned as-is, not coerced to bool.
        """
        return self.__data.get("auto_mode", True)

    def get_queue_name(self) -> str:
        """Return ``queue_name``; ``""`` when absent."""
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        """Return the raw ``resize_dimension`` value; ``1024`` when absent."""
        return self.__data.get("resize_dimension", 1024)

    def get_ti_guidance_scale(self) -> float:
        """Return the raw ``ti_guidance_scale`` value; ``7.5`` when absent."""
        return self.__data.get("ti_guidance_scale", 7.5)

    def get_i2i_guidance_scale(self) -> float:
        """Return the raw ``i2i_guidance_scale`` value; ``7.5`` when absent."""
        return self.__data.get("i2i_guidance_scale", 7.5)

    def get_i2i_strength(self) -> float:
        """Return the raw ``i2i_strength`` value; ``0.75`` when absent."""
        return self.__data.get("i2i_strength", 0.75)

    def get_cy_guidance_scale(self) -> float:
        """Return the raw ``cy_guidance_scale`` value; ``9`` when absent."""
        return self.__data.get("cy_guidance_scale", 9)

    def get_po_guidance_scale(self) -> float:
        """Return the raw ``po_guidance_scale`` value; ``7.5`` when absent."""
        return self.__data.get("po_guidance_scale", 7.5)

    def get_nsfw_threshold(self) -> float:
        """Return the raw ``nsfw_threshold`` value; ``0.03`` when absent."""
        return self.__data.get("nsfw_threshold", 0.03)

    def can_access_nsfw(self) -> bool:
        """Return the raw ``can_access_nsfw`` value; ``False`` when absent."""
        return self.__data.get("can_access_nsfw", False)

    def get_access_token(self) -> str:
        """Return ``access_token``; ``""`` when absent."""
        return self.__data.get("access_token", "")

    def get_raw(self) -> dict:
        """Return a shallow copy of the (normalised) payload dict."""
        return self.__data.copy()
|
|