aai/tabs/images/models.py
barreloflube's picture
Refactor flux_helpers.py to enable or disable Vae
37112ef
raw
history blame
2.4 kB
from typing import Any, Dict, List, Optional

from PIL import Image
from pydantic import BaseModel, ConfigDict, field_validator, model_validator

from config import Config as appConfig
class ControlNetReq(BaseModel):
    """ControlNet settings attached to an image-generation request.

    The three fields are lists; presumably entry ``i`` of each list describes
    the same ControlNet (name, conditioning image, scale) -- TODO confirm
    against the pipeline code that consumes this model.
    """

    # Pydantic v2 style config (class-based ``Config`` is deprecated in v2).
    # Arbitrary types must be allowed because ``PIL.Image.Image`` has no
    # pydantic schema of its own.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    controlnets: List[str]  # ["canny", "tile", "depth", "scribble"]
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]
class BaseReq(BaseModel):
    """Parameters shared by all image-generation requests.

    Construction also runs ``check_model``, which rejects options that are
    not supported when ``model`` refers to a Flux-loader entry in
    ``appConfig.IMAGES_MODELS``.
    """

    # Pydantic v2 style config (class-based ``Config`` is deprecated in v2).
    # Arbitrary types are allowed so subclasses can declare PIL image fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    model: str = ""
    prompt: str = ""
    negative_prompt: Optional[str] = None
    fast_generation: Optional[bool] = True
    loras: Optional[list] = []
    embeddings: Optional[list] = None
    resize_mode: Optional[str] = "resize_and_fill"  # resize_only, crop_and_resize, resize_and_fill
    scheduler: Optional[str] = "euler_fl"
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    clip_skip: Optional[int] = None
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    custom_addons: Optional[Dict[Any, Any]] = None

    @model_validator(mode="after")
    def check_model(self) -> "BaseReq":
        """Reject request options the selected Flux model cannot use.

        Every entry of ``appConfig.IMAGES_MODELS`` whose ``repo_id`` equals
        ``self.model`` and whose ``loader`` is ``"flux"`` triggers the checks
        below. Entries for other loaders, or an unknown model, pass unchanged.

        This runs after all fields are validated, so it sees the final field
        values. The previous ``field_validator`` received one field's value at
        a time and crashed calling ``.get`` on it.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: if a negative prompt, embeddings, clip skip, or a
                "scribble" ControlNet is set for a Flux model. Pydantic
                reports it to the caller as a ``ValidationError``.
        """
        for entry in appConfig.IMAGES_MODELS:
            if entry.get('repo_id') != self.model:
                continue
            if entry.get('loader') != "flux":
                continue
            # Truthiness checks mirror the original intent: empty / zero
            # values are treated as "not set".
            if self.negative_prompt:
                raise ValueError("Negative prompt is not supported for Flux models.")
            if self.embeddings:
                raise ValueError("Embeddings are not supported for Flux models.")
            if self.clip_skip:
                raise ValueError("Clip skip is not supported for Flux models.")
            if (
                self.controlnet_config is not None
                and "scribble" in self.controlnet_config.controlnets
            ):
                raise ValueError("Scribble is not supported for Flux models.")
        return self
class BaseImg2ImgReq(BaseReq):
    """Image-to-image request: all ``BaseReq`` parameters plus an input image."""

    # Pydantic v2 style config (class-based ``Config`` is deprecated in v2).
    # Needed for the ``PIL.Image.Image`` field below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    image: Image.Image
    # NOTE(review): presumably the img2img denoising strength; the valid range
    # is not enforced here -- confirm with the pipeline that consumes it.
    strength: float = 1.0
class BaseInpaintReq(BaseImg2ImgReq):
    """Inpainting request: an image-to-image request plus a mask image."""

    # Pydantic v2 style config (class-based ``Config`` is deprecated in v2).
    # Needed for the ``PIL.Image.Image`` field below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    mask_image: Image.Image