import gc
import random
from typing import List

import gradio as gr
import torch
from PIL import Image
from controlnet_aux.processor import Processor
from safetensors.torch import load_file
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
    FluxPipeline,
    FluxImg2ImgPipeline,
    FluxInpaintPipeline,
    FluxControlNetPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
from diffusers.schedulers import *
from huggingface_hub import hf_hub_download
from sd_embed.embedding_funcs import (
    get_weighted_text_embeddings_flux1,
    get_weighted_text_embeddings_sdxl,
)

from .models import *
from .load_models import device, models, flux_vae, sdxl_vae, refiner, controlnets

sd_pipes = (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
)
flux_pipes = (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)


def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    for model in models:
        if model['repo_id'] != request.model:
            continue

        pipe_args = {
            "pipeline": model['pipeline'],
        }

        # Set ControlNet config
        if request.controlnet_config:
            pipe_args["controlnet"] = []
            if model['loader'] in ('sdxl', 'flux'):
                for controlnet in controlnets:
                    if any(layer in controlnet['layers'] for layer in request.controlnet_config.controlnets):
                        pipe_args["controlnet"].append(controlnet['controlnet'])
            elif model['loader'] == 'flux-multi':
                controlnet = next((c for c in controlnets if c['loader'] == 'flux-multi'), None)
                if controlnet is not None:
                    # control_mode is the list of layer indices the multi-ControlNet expects
                    pipe_args['control_mode'] = [
                        controlnet['layers'].index(layer)
                        for layer in request.controlnet_config.controlnets
                    ]
                    pipe_args['controlnet'].append(controlnet['controlnet'])

        # Choose pipeline mode (checked from the most derived request type down)
        if not request.custom_addons:
            if isinstance(request, BaseInpaintReq):
                pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
            elif isinstance(request, BaseImg2ImgReq):
                pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
            elif isinstance(request, BaseReq):
                pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
        else:
            # Custom add-ons build their own pipeline; skip the configuration below,
            # which would crash on a None pipeline.
            pipe_args['pipeline'] = None
            return pipe_args

        # Enable or disable the custom VAE (Flux keeps its VAE either way, since the
        # pipeline cannot decode without one)
        if request.vae:
            pipe_args["pipeline"].vae = sdxl_vae if model['loader'] == 'sdxl' else flux_vae
        else:
            pipe_args["pipeline"].vae = None if model['loader'] == 'sdxl' else flux_vae

        # Set scheduler
        pipe_args["pipeline"].scheduler = load_scheduler(pipe_args["pipeline"], request.scheduler)

        # Set LoRAs. The Hyper-SD speed-up LoRA is appended to the same adapter list so a
        # single set_adapters() call activates everything; calling set_adapters() twice
        # would silently drop whichever adapters were set first.
        adapter_names, adapter_weights = [], []
        if request.loras:
            for i, lora in enumerate(request.loras):
                pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
                adapter_names.append(f"lora_{i}")
                adapter_weights.append(lora['weight'])

        if request.fast_generation:
            hyper_lora = hf_hub_download(
                "ByteDance/Hyper-SD",
                "Hyper-SDXL-8steps-lora.safetensors" if model['loader'] == 'sdxl'
                else "Hyper-FLUX.1-dev-8steps-lora.safetensors",
            )
            hyper_weight = 1.0 if model['loader'] == 'sdxl' else 0.125
            pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
            adapter_names.append("hyper_lora")
            adapter_weights.append(hyper_weight)

        if adapter_names:
            pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)

        # Set textual-inversion embeddings (SDXL only; the token is registered with both text encoders)
        if request.embeddings and model['loader'] == 'sdxl':
            for embedding in request.embeddings:
                # hf_hub_download needs a filename in addition to the repo id; a 'filename'
                # entry on the embedding dict is assumed here.
                state_dict = load_file(hf_hub_download(embedding['repo_id'], embedding['filename']))
                pipe_args["pipeline"].load_textual_inversion(
                    state_dict['clip_g'],
                    token=embedding['token'],
                    text_encoder=pipe_args["pipeline"].text_encoder_2,
                    tokenizer=pipe_args["pipeline"].tokenizer_2,
                )
                pipe_args["pipeline"].load_textual_inversion(
                    state_dict["clip_l"],
                    token=embedding['token'],
                    text_encoder=pipe_args["pipeline"].text_encoder,
                    tokenizer=pipe_args["pipeline"].tokenizer,
                )

        return pipe_args
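# Example (sketch, not executed): assuming `.models` defines BaseReq with the fields
# get_pipe() reads above, resolving a text-to-image pipeline could look like:
#
#   request = BaseReq(
#       model="stabilityai/stable-diffusion-xl-base-1.0",   # must match a models[] repo_id
#       scheduler="euler", vae=True, fast_generation=False,
#       loras=[], embeddings=[], controlnet_config=None, custom_addons=None,
#   )
#   pipeline = get_pipe(request)["pipeline"]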
def load_scheduler(pipeline, scheduler: str):
    # Map UI scheduler names to diffusers scheduler classes plus their extra config flags
    schedulers = {
        "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
        "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
        "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
        "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "dpm2": (KDPM2DiscreteScheduler, {}),
        "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
        "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler_a": (EulerAncestralDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "lms": (LMSDiscreteScheduler, {}),
        "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "deis": (DEISMultistepScheduler, {}),
        "unipc": (UniPCMultistepScheduler, {}),
        "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
    }

    scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
    if scheduler_class is None:
        raise ValueError(f"Unknown scheduler: {scheduler}")
    return scheduler_class.from_config(pipeline.scheduler.config, **kwargs)


def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str) -> List[Image.Image]:
    # PIL operations return new images, so collect the results instead of discarding them
    resized = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            image = image.crop((0, 0, width, height))
        elif resize_mode == "resize_and_fill":
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        resized.append(image)
    return resized


def get_controlnet_images(controlnet_layers: List[str], control_images: List[Image.Image],
                          height: int, width: int, resize_mode: str) -> List[Image.Image]:
    # Map each requested ControlNet layer to its controlnet_aux preprocessor id
    processor_ids = {
        "canny": "canny",
        "depth": "depth_midas",
        "pose": "openpose_full",
        "scribble": "scribble",
    }
    response_images = []
    control_images = resize_images(control_images, height, width, resize_mode)
    for layer, image in zip(controlnet_layers, control_images):
        if layer not in processor_ids:
            raise ValueError(f"Invalid ControlNet: {layer}")
        processor = Processor(processor_ids[layer])
        response_images.append(processor(image, to_pil=True))
    return response_images


def get_control_mode(controlnet_config: ControlNetReq) -> List[int]:
    # Translate requested layer names into the index list a flux-multi ControlNet expects
    control_mode = []
    for controlnet in controlnets:
        if controlnet['loader'] == 'flux-multi':
            layers = controlnet['layers']
            for c in controlnet_config.controlnets:
                if c in layers:
                    control_mode.append(layers.index(c))
    return control_mode
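# Example (sketch, not executed): preprocessing conditioning images for a request that
# asks for "canny" and "depth" layers; get_control_mode() maps the same layer names to
# the index list a flux-multi ControlNet expects. `img_a`, `img_b`, and `config` are
# hypothetical PIL images / ControlNetReq values, not names defined in this module.
#
#   cond = get_controlnet_images(["canny", "depth"], [img_a, img_b], 1024, 1024, "resize_only")
#   mode = get_control_mode(config)   # config.controlnets == ["canny", "depth"]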
# def check_image_safety(images: List[Image.Image]):
#     safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
#     has_nsfw_concepts = safety_checker(
#         images=[images],
#         clip_input=safety_checker_input.pixel_values.to("cuda"),
#     )
#     return has_nsfw_concepts[1]


# def get_prompt_attention(pipeline, prompt, negative_prompt):
#     if isinstance(pipeline, flux_pipes):
#         prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt, device=device)
#         return prompt_embeds, None, pooled_prompt_embeds, None
#     elif isinstance(pipeline, sd_pipes):
#         prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt, device=device)
#         return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


def cleanup(pipeline, loras=None, embeddings=None):
    # Unload per-request adapters/embeddings and release GPU memory
    if loras:
        # pipeline.disable_lora()
        pipeline.unload_lora_weights()
    if embeddings:
        pipeline.unload_textual_inversion()
    gc.collect()
    torch.cuda.empty_cache()


# Gen function
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq, progress=gr.Progress(track_tqdm=True)):
    progress(0.1, "Loading Pipeline")
    pipeline_args = get_pipe(request)
    pipeline = pipeline_args["pipeline"]

    try:
        progress(0.3, "Getting Prompt Embeddings")

        # Get prompt embeddings (sd_embed handles attention weighting and long prompts)
        negative_prompt_embeds = negative_prompt_pooled = None
        if isinstance(pipeline, flux_pipes):
            positive_prompt_embeds, positive_prompt_pooled = get_weighted_text_embeddings_flux1(pipeline, request.prompt)
        elif isinstance(pipeline, sd_pipes):
            positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = \
                get_weighted_text_embeddings_sdxl(pipeline, request.prompt, request.negative_prompt)

        progress(0.5, "Configuring Pipeline")

        # Common args
        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            # One generator per image: offset the request seed when one is given,
            # otherwise draw a fresh random seed (None, 0, and -1 all mean "random").
            'generator': [
                torch.Generator().manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator().manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if isinstance(pipeline, sd_pipes):
            args['clip_skip'] = request.clip_skip
            args['negative_prompt_embeds'] = negative_prompt_embeds
            args['negative_pooled_prompt_embeds'] = negative_prompt_pooled

        if request.controlnet_config:
            control_images = get_controlnet_images(
                request.controlnet_config.controlnets,
                request.controlnet_config.control_images,
                request.height, request.width, request.resize_mode,
            )
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale
            # SDXL ControlNet text-to-image takes its conditioning image as `image`;
            # the other ControlNet pipelines used here take `control_image`.
            if isinstance(pipeline, StableDiffusionXLControlNetPipeline):
                args['image'] = control_images
            else:
                args['control_image'] = control_images
            if isinstance(pipeline, flux_pipes):
                args['control_mode'] = get_control_mode(request.controlnet_config)

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args['strength'] = request.strength

        if isinstance(request, BaseInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        # Generate
        progress(0.9, "Generating Images")
        gr.Info(f"Request {type(request)}: {str(request.__dict__)}", duration=60)
        images = pipeline(**args).images

        # Refiner pass (denoising_start=0.7 hands the last 30% of the schedule to the refiner)
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        progress(1.0, "Cleaning Up")
        cleanup(pipeline, request.loras, request.embeddings)

        return images
    except Exception as e:
        cleanup(pipeline, request.loras, request.embeddings)
        raise gr.Error(f"Error: {e}")
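# Example (sketch, not executed): gen_img is designed to sit behind a Gradio event
# handler. The BaseReq fields shown are inferred from how gen_img reads them above and
# are not a confirmed schema.
#
#   def on_generate(prompt: str):
#       request = BaseReq(model=models[0]['repo_id'], prompt=prompt, negative_prompt="",
#                         height=1024, width=1024, num_images_per_prompt=1,
#                         num_inference_steps=24, guidance_scale=5.0, seed=42,
#                         scheduler="euler", vae=True, fast_generation=False,
#                         loras=[], embeddings=[], controlnet_config=None,
#                         custom_addons=None, refiner=False)
#       return gen_img(request)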