|
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from PIL import Image

from config import Config as appConfig
|
|
|
|
|
class ControlNetReq(BaseModel):
    """ControlNet settings attached to a generation request.

    The three lists look parallel (one control image and one conditioning
    scale per controlnet name). That is NOT enforced here — TODO confirm
    against the consuming pipeline.
    """

    # ``PIL.Image.Image`` is not a pydantic-native type, so arbitrary types
    # must be allowed for ``control_images`` to validate (isinstance check).
    # ``model_config`` replaces the v1-style inner ``class Config``, which is
    # deprecated in pydantic v2.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Names of the controlnets to apply; "scribble" is one value that
    # BaseReq's Flux check looks for.
    controlnets: List[str]
    # Conditioning images for the controlnets.
    control_images: List[Image.Image]
    # Conditioning strengths for the controlnets.
    controlnet_conditioning_scale: List[float]
|
|
|
|
|
class BaseReq(BaseModel):
    """Parameters shared by every image-generation request.

    After all fields are parsed, ``check_model`` rejects options that the
    loader configured for ``model`` (in ``appConfig.IMAGES_MODELS``) does not
    support.
    """

    # Allows non-pydantic types on this model and its subclasses.
    # ``model_config`` replaces the v1-style inner ``class Config``, which is
    # deprecated in pydantic v2.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Matched against the ``repo_id`` of entries in appConfig.IMAGES_MODELS.
    model: str = ""
    prompt: str = ""
    negative_prompt: Optional[str] = None
    fast_generation: Optional[bool] = True
    # Mutable default is safe here: pydantic copies defaults per instance.
    loras: Optional[list] = []
    embeddings: Optional[list] = None
    resize_mode: Optional[str] = "resize_and_fill"
    scheduler: Optional[str] = "euler_fl"
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    clip_skip: Optional[int] = None
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    custom_addons: Optional[Dict[Any, Any]] = None

    @model_validator(mode="after")
    def check_model(self) -> "BaseReq":
        """Reject options unsupported by Flux models.

        The request is treated as Flux when any entry of
        ``appConfig.IMAGES_MODELS`` has ``repo_id == self.model`` and
        ``loader == "flux"``. For such requests, a truthy
        ``negative_prompt``, ``embeddings`` or ``clip_skip`` is rejected, as
        is a ``controlnet_config`` whose ``controlnets`` contains
        ``"scribble"``.

        This is a model validator (not a field validator) because it compares
        several fields with each other. A field validator receives only a
        single field's value.

        Returns:
            ``self`` unchanged, as pydantic requires for ``mode="after"``.

        Raises:
            ValueError: on an unsupported combination. Pydantic surfaces it
                to callers as a ``ValidationError``.
        """
        is_flux = any(
            entry.get("repo_id") == self.model and entry.get("loader") == "flux"
            for entry in appConfig.IMAGES_MODELS
        )
        if not is_flux:
            return self

        # Truthiness checks mirror the original intent: empty/zero values pass.
        if self.negative_prompt:
            raise ValueError("Negative prompt is not supported for Flux models.")
        if self.embeddings:
            raise ValueError("Embeddings are not supported for Flux models.")
        if self.clip_skip:
            raise ValueError("Clip skip is not supported for Flux models.")
        if (
            self.controlnet_config is not None
            and "scribble" in self.controlnet_config.controlnets
        ):
            raise ValueError("Scribble is not supported for Flux models.")
        return self
|
|
|
|
|
class BaseImg2ImgReq(BaseReq):
    """Image-to-image request: the ``BaseReq`` parameters plus a source image."""

    # ``PIL.Image.Image`` needs arbitrary types allowed. ``model_config``
    # replaces the deprecated v1-style inner ``class Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Source image; required (no default).
    image: Image.Image
    # Defaults to 1.0; presumably the img2img denoising strength — confirm
    # against the consuming pipeline.
    strength: float = 1.0
|
|
|
|
|
class BaseInpaintReq(BaseImg2ImgReq):
    """Inpainting request: an image-to-image request plus a mask image."""

    # ``PIL.Image.Image`` needs arbitrary types allowed. ``model_config``
    # replaces the deprecated v1-style inner ``class Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Mask image; required (no default). Presumably marks the region of
    # ``image`` to repaint — TODO confirm the mask convention.
    mask_image: Image.Image
|
|