aai / tabs /images /models.py
barreloflube's picture
Refactor field validation in models.py
0f49e30
raw
history blame
2.52 kB
from typing import Any, Dict, List, Optional

from PIL import Image
from pydantic import BaseModel, ConfigDict, Field, field_validator

from config import Config as appConfig
class ControlNetReq(BaseModel):
    """ControlNet settings attached to an image-generation request."""

    # PIL ``Image.Image`` is not a pydantic-native type, so arbitrary types must
    # be allowed explicitly (v2 ``model_config`` replaces the deprecated
    # nested ``class Config``).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Names of the ControlNets to apply, e.g. "canny", "tile", "depth", "scribble".
    controlnets: List[str]
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]
    # NOTE(review): these three lists look parallel (one entry per ControlNet),
    # but their lengths are not cross-checked here — confirm against the pipeline.
class BaseReq(BaseModel):
    """Common parameters shared by every image-generation request.

    Each field's default is used when the caller omits it. ``Optional`` fields
    also accept ``None``.
    """

    # v2-style config (replaces the deprecated nested ``class Config``); the
    # setting is inherited by the image-carrying subclasses.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    model: str = ""
    prompt: str = ""
    negative_prompt: Optional[str] = None
    fast_generation: Optional[bool] = True
    # default_factory builds a fresh list per instance; this is the idiomatic
    # form of a mutable default.
    loras: Optional[list] = Field(default_factory=list)
    embeddings: Optional[list] = None
    # One of: "resize_only", "crop_and_resize", "resize_and_fill".
    resize_mode: Optional[str] = "resize_and_fill"
    scheduler: Optional[str] = "euler_fl"
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    clip_skip: Optional[int] = None
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    custom_addons: Optional[Dict[Any, Any]] = None

    # NOTE(review): model-specific validation is currently disabled.
    #
    # An earlier draft tried to reject these inputs for models whose
    # ``appConfig.IMAGES_MODELS`` entry (matched on ``repo_id == model``) has
    # ``loader == "flux"``:
    #   - a negative_prompt
    #   - embeddings
    #   - clip_skip
    #   - a "scribble" entry in ``controlnet_config.controlnets``
    #
    # That draft was written as a ``field_validator`` that read a ``values``
    # dict, but v2 field validators receive a single field value. If
    # re-enabled, it belongs in a ``model_validator(mode="after")``.
class BaseImg2ImgReq(BaseReq):
    """Image-to-image request: every ``BaseReq`` field plus a source image."""

    # Needed for the PIL ``Image.Image`` field (v2 replacement for ``class Config``).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image: Image.Image
    # Defaults to 1.0; no range check is applied here.
    strength: float = 1.0
class BaseInpaintReq(BaseImg2ImgReq):
    """Inpainting request: every ``BaseImg2ImgReq`` field plus a mask image."""

    # Needed for the PIL ``Image.Image`` field (v2 replacement for ``class Config``).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # NOTE(review): presumably marks the region of ``image`` to repaint —
    # confirm the mask convention with the inpainting pipeline.
    mask_image: Image.Image