|
from enum import Enum |
|
from functools import lru_cache |
|
from typing import Union |
|
|
|
import numpy as np |
|
|
|
|
|
class TaskType(Enum):
    """Kinds of work a task payload can request.

    Members are looked up by the payload's ``task_type`` string value;
    ``Task.get_type`` returns None when that value matches no member.
    Values are wire strings — do not change them without updating producers.
    """

    # NOTE: the wire value differs from the member name; keep the string as-is.
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"

    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"

    POSE = "POSE"

    CANNY = "CANNY"

    REMOVE_BG = "REMOVE_BG"

    INPAINT = "INPAINT"

    UPSCALE_IMAGE = "UPSCALE_IMAGE"

    TILE_UPSCALE = "TILE_UPSCALE"

    OBJECT_REMOVAL = "OBJECT_REMOVAL"

    SCRIBBLE = "SCRIBBLE"

    LINEARART = "LINEARART"

    REPLACE_BG = "REPLACE_BG"

    SYSTEM_CMD = "SYSTEM_CMD"
|
|
|
|
|
class ModelType(Enum):
    """Integer model-family ids carried in a task payload's ``modelId`` field.

    ``Task.get_modelType`` converts the payload id into a member; an id that
    is not listed here raises ValueError (standard Enum lookup behaviour).
    """

    # Task.get_model_id falls back to this id when ``modelId`` is absent.
    REAL = 10000

    ANIME = 10001

    COMIC = 10002
|
|
|
|
|
class Task:
    """Accessor wrapper around one task payload mapping.

    The payload is shallow-copied on construction (the caller's mapping is
    never modified) and two fields of the copy are normalised:

    * ``seed`` -- a missing, ``None`` or ``-1`` seed is replaced by a random
      integer in ``[0, 2**31 - 1)``.
    * ``prompt`` -- ``None`` becomes ``""``; a prompt longer than
      ``MAX_PROMPT_LENGTH`` characters is cut to that length and ``", "`` is
      appended.

    Every other getter reads a single payload key and falls back to the
    default shown in its body when the key is absent.
    """

    # Longest prompt kept verbatim; longer prompts are truncated in __init__.
    MAX_PROMPT_LENGTH = 200

    class _PromptMethods:
        """Prompt-template helpers bound to a single Task.

        The task's prompt is re-read on every call, so results never go stale.
        """

        # Placeholder token that merge_blip() substitutes.
        _BLIP_MERGE_PLACEHOLDER = "<blip:[merge]>"

        def __init__(self, task: "Task"):
            self.__task = task

        def has_placeholder_blip_merge(self) -> bool:
            """Return True if the task prompt contains the blip-merge placeholder."""
            return self._BLIP_MERGE_PLACEHOLDER in self.__task.get_prompt()

        def merge_blip(self, text: str) -> str:
            """Return the task prompt with every blip-merge placeholder replaced by *text*."""
            return self.__task.get_prompt().replace(self._BLIP_MERGE_PLACEHOLDER, text)

    def __init__(self, data):
        """Wrap *data*, a mapping of payload fields, normalising seed and prompt.

        Raises:
            ValueError: if ``seed`` is present but not convertible by ``int()``.
        """
        # Work on a shallow copy: the normalisation below previously wrote
        # straight into the caller's dict.
        self.__data = dict(data)

        seed = self.__data.get("seed", -1)
        if seed is None or int(seed) == -1:
            # randint's upper bound is exclusive, so the seed fits in an int32.
            self.__data["seed"] = int(np.random.randint(0, np.iinfo(np.int32).max))

        prompt = self.__data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        elif len(prompt) > self.MAX_PROMPT_LENGTH:
            # NOTE(review): the trailing ", " is preserved from the original
            # behaviour; presumably later code appends more terms -- confirm.
            self.__data["prompt"] = prompt[: self.MAX_PROMPT_LENGTH] + ", "

    def get_taskId(self) -> Union[str, None]:
        """Return ``task_id``, or None when absent."""
        return self.__data.get("task_id")

    def get_sourceId(self) -> Union[str, None]:
        """Return ``source_id``, or None when absent."""
        return self.__data.get("source_id")

    def get_imageUrl(self) -> Union[str, None]:
        """Return ``imageUrl``, or None when absent."""
        return self.__data.get("imageUrl", None)

    def get_prompt(self) -> str:
        """Return the (normalised) prompt; ``""`` when absent."""
        return self.__data.get("prompt", "")

    def get_prompt_left(self) -> str:
        """Return ``prompt_left``; ``""`` when absent."""
        return self.__data.get("prompt_left", "")

    def get_prompt_right(self) -> str:
        """Return ``prompt_right``; ``""`` when absent."""
        return self.__data.get("prompt_right", "")

    def get_userId(self) -> str:
        """Return ``userId``; ``""`` when absent."""
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        """Return ``email``; ``""`` when absent."""
        return self.__data.get("email", "")

    def get_style(self) -> Union[str, None]:
        """Return ``style``, or None when absent."""
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        """Return ``iteration`` converted to float; 3.0 when absent."""
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> "ModelType":
        """Return the ModelType member for ``get_model_id()``.

        Raises:
            ValueError: if the id does not match any ModelType member.
        """
        model_id = self.get_model_id()
        return ModelType(model_id)

    def get_model_id(self) -> int:
        """Return ``modelId`` converted to int; 10000 when absent."""
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        """Return ``width`` converted to int; 512 when absent."""
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        """Return ``height`` converted to int; 512 when absent."""
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        """Return the seed as int (always set by __init__ for normal construction)."""
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        """Return ``steps`` converted to int; 75 when absent."""
        return int(self.__data.get("steps", 75))

    def get_type(self) -> Union["TaskType", None]:
        """Return the TaskType for ``task_type``, or None if missing/unknown."""
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> Union[str, None]:
        """Return ``maskImageUrl``, or None when absent."""
        return self.__data.get("maskImageUrl")

    def get_pose_coordinates(self) -> Union[dict, None]:
        """Return ``pose_coordinates``, or None when absent."""
        return self.__data.get("pose_coordinates", None)

    def get_pose_estimation(self) -> bool:
        """Return ``pose_estimation``; True when absent."""
        return self.__data.get("pose_estimation", True)

    def get_negative_prompt(self) -> str:
        """Return ``negative_prompt``; ``""`` when absent."""
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        """Return the ``auto_mode`` flag; True when absent."""
        return self.__data.get("auto_mode", True)

    def get_image_scale(self) -> float:
        """Return ``image_scale`` as stored; 1.0 when absent."""
        return self.__data.get("image_scale", 1.0)

    def get_queue_name(self) -> str:
        """Return ``queue_name``; ``""`` when absent."""
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        """Return ``resize_dimension`` as stored; 1024 when absent."""
        return self.__data.get("resize_dimension", 1024)

    def get_face_enhance(self) -> bool:
        """Return the ``up_face_enhance`` flag; False when absent."""
        return self.__data.get("up_face_enhance", False)

    def get_ti_guidance_scale(self) -> float:
        """Return ``ti_guidance_scale`` as stored; 7.5 when absent."""
        return self.__data.get("ti_guidance_scale", 7.5)

    def get_i2i_guidance_scale(self) -> float:
        """Return ``i2i_guidance_scale`` as stored; 7.5 when absent."""
        return self.__data.get("i2i_guidance_scale", 7.5)

    def get_i2i_strength(self) -> float:
        """Return ``i2i_strength`` as stored; 0.75 when absent."""
        return self.__data.get("i2i_strength", 0.75)

    def get_cy_guidance_scale(self) -> float:
        """Return ``cy_guidance_scale`` as stored; 9 when absent."""
        return self.__data.get("cy_guidance_scale", 9)

    def get_po_guidance_scale(self) -> float:
        """Return ``po_guidance_scale`` as stored; 7.5 when absent."""
        return self.__data.get("po_guidance_scale", 7.5)

    def rbg_controlnet_conditioning_scale(self) -> float:
        """Return ``rbg_conditioning_scale`` as stored; 0.5 when absent."""
        return self.__data.get("rbg_conditioning_scale", 0.5)

    def get_nsfw_threshold(self) -> float:
        """Return ``nsfw_threshold`` as stored; 0.03 when absent."""
        return self.__data.get("nsfw_threshold", 0.03)

    def can_access_nsfw(self) -> bool:
        """Return the ``can_access_nsfw`` flag; False when absent."""
        return self.__data.get("can_access_nsfw", False)

    def get_access_token(self) -> str:
        """Return ``access_token``; ``""`` when absent."""
        return self.__data.get("access_token", "")

    def get_high_res_fix(self) -> bool:
        """Return the ``high_res_fix`` flag; False when absent."""
        return self.__data.get("high_res_fix", False)

    def get_raw(self) -> dict:
        """Return a shallow copy of the normalised payload."""
        return self.__data.copy()

    @property
    def PROMPT(self) -> "Task._PromptMethods":
        """Prompt helper object for this task, created once per instance."""
        # Per-instance lazy cache. A method-level lru_cache(1) was shared by
        # every Task, pinned the most recent instance in memory and rebuilt
        # the helper whenever two tasks were used alternately.
        try:
            return self.__prompt_methods
        except AttributeError:
            self.__prompt_methods = self._PromptMethods(self)
            return self.__prompt_methods
|
|