Refactor get_pipe function to handle different request types and pipeline configurations
c7554bf
import gc
import random
from typing import List, Optional

import torch
import numpy as np
from pydantic import BaseModel
from PIL import Image
from diffusers import (
    FluxPipeline,
    FluxImg2ImgPipeline,
    FluxInpaintPipeline,
    FluxControlNetPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from diffusers.schedulers import *
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from controlnet_aux.processor import Processor
from photomaker import (
    PhotoMakerStableDiffusionXLPipeline,
    PhotoMakerStableDiffusionXLControlNetPipeline,
    analyze_faces,
)
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_flux1

from .init_sys import device, models, refiner, safety_checker, feature_extractor, controlnet_models, face_detector
# Models
class ControlNetReq(BaseModel):
    controlnets: List[str]  # e.g. ["canny", "tile", "depth"]
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]

    class Config:
        arbitrary_types_allowed = True


class SDReq(BaseModel):
    model: str = "black-forest-labs/FLUX.1-dev"
    prompt: str = ""
    negative_prompt: Optional[str] = ""
    fast_generation: Optional[bool] = True
    loras: Optional[list] = []
    embeddings: Optional[list] = []
    resize_mode: Optional[str] = "resize_and_fill"  # resize_only, crop_and_resize, resize_and_fill
    scheduler: Optional[str] = "fm_euler"
    clip_skip: Optional[int] = None
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    photomaker_images: Optional[List[Image.Image]] = None

    class Config:
        arbitrary_types_allowed = True


class SDImg2ImgReq(SDReq):
    image: Image.Image
    strength: float = 1.0

    class Config:
        arbitrary_types_allowed = True


class SDInpaintReq(SDImg2ImgReq):
    mask_image: Image.Image

    class Config:
        arbitrary_types_allowed = True
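
# Example (hypothetical values, illustrative only): an img2img request
# combining a LoRA with a canny ControlNet. `init_image` stands in for any
# PIL image; the repo ids are placeholders.
#
#   request = SDImg2ImgReq(
#       model="stabilityai/stable-diffusion-xl-base-1.0",
#       prompt="a red sports car, studio lighting",
#       image=init_image,
#       strength=0.6,
#       loras=[{"repo_id": "some-user/some-lora", "weight": 0.8}],
#       controlnet_config=ControlNetReq(
#           controlnets=["canny"],
#           control_images=[init_image],
#           controlnet_conditioning_scale=[0.5],
#       ),
#   )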
# Helper functions
def get_controlnet(controlnet_config: ControlNetReq):
    control_mode = []
    controlnet = []
    for m in controlnet_models:
        for c in controlnet_config.controlnets:
            if c in m["layers"]:
                control_mode.append(m["layers"].index(c))
                controlnet.append(m["controlnet"])
    return controlnet, control_mode

def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
    pipeline = None
    for m in models:
        if m["repo_id"] == request.model:
            pipeline = m["pipeline"]
    if pipeline is None:
        raise ValueError(f"Unknown model: {request.model}")

    controlnet, control_mode = get_controlnet(request.controlnet_config) if request.controlnet_config else (None, None)
    pipe_args = {
        "pipeline": pipeline,
        "control_mode": control_mode,
    }
    if request.controlnet_config:
        pipe_args["controlnet"] = controlnet

    if not request.photomaker_images:
        # Check the most-derived request types first: SDInpaintReq subclasses
        # SDImg2ImgReq, which subclasses SDReq, so testing SDReq first would
        # match every request.
        if isinstance(request, SDInpaintReq):
            pipe_args["pipeline"] = AutoPipelineForInpainting.from_pipe(**pipe_args)
        elif isinstance(request, SDImg2ImgReq):
            pipe_args["pipeline"] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
        elif isinstance(request, SDReq):
            pipe_args["pipeline"] = AutoPipelineForText2Image.from_pipe(**pipe_args)
        else:
            raise ValueError(f"Unknown request type: {type(request)}")
    else:
        # PhotoMaker requests route to the PhotoMaker SDXL pipelines.
        if request.controlnet_config:
            pipe_args["pipeline"] = PhotoMakerStableDiffusionXLControlNetPipeline.from_pipe(**pipe_args)
        else:
            pipe_args["pipeline"] = PhotoMakerStableDiffusionXLPipeline.from_pipe(**pipe_args)

    return pipe_args

def load_scheduler(pipeline, scheduler):
    schedulers = {
        "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
        "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
        "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
        "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "dpm2": (KDPM2DiscreteScheduler, {}),
        "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
        "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler_a": (EulerAncestralDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "lms": (LMSDiscreteScheduler, {}),
        "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "deis": (DEISMultistepScheduler, {}),
        "unipc": (UniPCMultistepScheduler, {}),
        "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
    }
    scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
    if scheduler_class is None:
        raise ValueError(f"Unknown scheduler: {scheduler}")
    return scheduler_class.from_config(pipeline.scheduler.config, **kwargs)

def load_loras(pipeline, loras, fast_generation):
    # Each lora entry is expected to be a dict with 'repo_id' and 'weight' keys.
    for i, lora in enumerate(loras):
        pipeline.load_lora_weights(lora["repo_id"], adapter_name=f"lora_{i}")
    adapter_names = [f"lora_{i}" for i in range(len(loras))]
    adapter_weights = [lora["weight"] for lora in loras]

    if fast_generation:
        # Hyper-SD LoRAs enable low step counts; the FLUX variant is merged at
        # the fractional weight recommended by ByteDance.
        hyper_lora = hf_hub_download(
            "ByteDance/Hyper-SD",
            "Hyper-FLUX.1-dev-8steps-lora.safetensors" if isinstance(pipeline, FluxPipeline) else "Hyper-SDXL-2steps-lora.safetensors"
        )
        hyper_weight = 0.125 if isinstance(pipeline, FluxPipeline) else 1.0
        pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
        adapter_names.append("hyper_lora")
        adapter_weights.append(hyper_weight)

    pipeline.set_adapters(adapter_names, adapter_weights)

def load_xl_embeddings(pipeline, embeddings):
    for embedding in embeddings:
        # hf_hub_download requires both a repo id and a filename; each
        # embedding entry is assumed to carry a 'filename' key alongside
        # 'repo_id' and 'token'.
        state_dict = load_file(hf_hub_download(embedding["repo_id"], embedding["filename"]))
        pipeline.load_textual_inversion(state_dict["clip_g"], token=embedding["token"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
        pipeline.load_textual_inversion(state_dict["clip_l"], token=embedding["token"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)

def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    # Collect the results instead of rebinding the loop variable, which would
    # leave the input list unchanged.
    resized_images = []
    for image in images:
        if resize_mode == "resize_only":
            resized_images.append(image.resize((width, height)))
        elif resize_mode == "crop_and_resize":
            resized_images.append(image.crop((0, 0, width, height)))
        elif resize_mode == "resize_and_fill":
            resized_images.append(image.resize((width, height), Image.Resampling.LANCZOS))
        else:
            raise ValueError(f"Invalid resize mode: {resize_mode}")
    return resized_images

def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
    response_images = []
    control_images = resize_images(control_images, height, width, resize_mode)
    for controlnet, image in zip(controlnets, control_images):
        if controlnet in ("canny", "canny_xs", "canny_fl"):
            processor = Processor("canny")
        elif controlnet in ("depth", "depth_xs", "depth_fl"):
            processor = Processor("depth_midas")
        elif controlnet in ("pose", "pose_fl"):
            processor = Processor("openpose_full")
        elif controlnet == "scribble":
            processor = Processor("scribble")
        else:
            raise ValueError(f"Invalid ControlNet: {controlnet}")
        response_images.append(processor(image, to_pil=True))
    return response_images

def check_image_safety(images: List[Image.Image]):
    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
    # The safety checker returns (images, has_nsfw_concepts); we only need the
    # per-image NSFW flags.
    has_nsfw_concepts = safety_checker(
        images=images,
        clip_input=safety_checker_input.pixel_values.to(device),
    )
    return has_nsfw_concepts[1]

def get_prompt_attention(pipeline, prompt, negative_prompt):
    if isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)):
        # FLUX has no negative-prompt path, so the negative embeddings are None.
        prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
        return prompt_embeds, None, pooled_prompt_embeds, None
    elif isinstance(pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
                               StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline,
                               StableDiffusionXLControlNetInpaintPipeline)):
        prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
        return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
    else:
        raise ValueError(f"Invalid pipeline type: {type(pipeline)}")

def get_photomaker_images(photomaker_images: List[Image.Image], height: int, width: int, resize_mode: str):
    image_input_ids = []
    image_id_embeds = []
    photomaker_images = resize_images(photomaker_images, height, width, resize_mode)
    for image in photomaker_images:
        image_input_ids.append(image)
        # analyze_faces expects a BGR numpy array, so flip the RGB channels.
        img = np.array(image)[:, :, ::-1]
        faces = analyze_faces(face_detector, img)
        if len(faces) > 0:
            image_id_embeds.append(torch.from_numpy(faces[0]["embedding"]))
        else:
            raise ValueError("No face detected in the image")
    return image_input_ids, image_id_embeds

def cleanup(pipeline, loras=None, embeddings=None):
    if loras:
        pipeline.disable_lora()
        pipeline.unload_lora_weights()
    if embeddings:
        pipeline.unload_textual_inversion()
    gc.collect()
    torch.cuda.empty_cache()

# Gen function
def gen_img(
    request: SDReq | SDImg2ImgReq | SDInpaintReq
):
    pipeline_args = get_pipe(request)
    pipeline = pipeline_args["pipeline"]
    try:
        pipeline.scheduler = load_scheduler(pipeline, request.scheduler)
        load_loras(pipeline, request.loras, request.fast_generation)
        load_xl_embeddings(pipeline, request.embeddings)

        control_images = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode) if request.controlnet_config else None
        photomaker_images, photomaker_id_embeds = get_photomaker_images(request.photomaker_images, request.height, request.width, request.resize_mode) if request.photomaker_images else (None, None)

        positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)

        # Common args
        args = {
            "prompt_embeds": positive_prompt_embeds,
            "pooled_prompt_embeds": positive_prompt_pooled,
            "height": request.height,
            "width": request.width,
            "num_images_per_prompt": request.num_images_per_prompt,
            "num_inference_steps": request.num_inference_steps,
            "guidance_scale": request.guidance_scale,
            # Use the request seed (offset per image) when one is given; fall
            # back to a random seed when it is None, 0, or -1.
            "generator": [
                torch.Generator(device=device).manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if isinstance(pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
                                 StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline)):
            args["clip_skip"] = request.clip_skip
            args["negative_prompt_embeds"] = negative_prompt_embeds
            args["negative_pooled_prompt_embeds"] = negative_prompt_pooled

        if isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
            args["control_mode"] = pipeline_args["control_mode"]
            args["control_image"] = control_images
            args["controlnet_conditioning_scale"] = request.controlnet_config.controlnet_conditioning_scale

        if not isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
            args["controlnet_conditioning_scale"] = request.controlnet_config.controlnet_conditioning_scale
            # Check the subclasses before SDReq, which every request matches.
            if isinstance(request, (SDImg2ImgReq, SDInpaintReq)):
                args["control_image"] = control_images
            elif isinstance(request, SDReq):
                args["image"] = control_images

        if request.photomaker_images and isinstance(pipeline, (PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline)):
            args["input_id_images"] = photomaker_images
            args["id_embeds"] = photomaker_id_embeds
            args["start_merge_step"] = 10

        # SDInpaintReq subclasses SDImg2ImgReq, so test it first.
        if isinstance(request, SDInpaintReq):
            args["image"] = resize_images([request.image], request.height, request.width, request.resize_mode)
            args["mask_image"] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)
            args["strength"] = request.strength
        elif isinstance(request, SDImg2ImgReq):
            args["image"] = resize_images([request.image], request.height, request.width, request.resize_mode)
            args["strength"] = request.strength

        images = pipeline(**args).images

        if request.refiner:
            images = refiner(
                prompt=request.prompt,
                num_inference_steps=40,
                denoising_start=0.7,
                image=images,
            ).images

        cleanup(pipeline, request.loras, request.embeddings)
        return images
    except Exception as e:
        cleanup(pipeline, request.loras, request.embeddings)
        raise ValueError(f"Error generating image: {e}") from e
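
# Minimal usage sketch (illustrative, not part of the service). It assumes
# `models` in init_sys registers the default FLUX.1-dev repo id so get_pipe
# can resolve a pipeline for it; the prompt and seed are arbitrary.
if __name__ == "__main__":
    example_request = SDReq(
        prompt="a photo of an astronaut riding a horse on the moon",
        num_inference_steps=8,
        guidance_scale=3.5,
        seed=42,
    )
    generated = gen_img(example_request)
    generated[0].save("output.png")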