File size: 5,392 Bytes
42ae52a 3a5022f daf9c75 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a a685f13 42ae52a a685f13 07dc8e6 37112ef 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a ecfb7d9 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a b6b27f8 42ae52a b6b27f8 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a 07dc8e6 42ae52a ab735b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import random
import gradio as gr
import torch
from diffusers import (
AutoPipelineForText2Image,
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
)
from huggingface_hub import hf_hub_download
from diffusers.schedulers import *
from huggingface_hub import hf_hub_download
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
from modules.pipelines.flux_pipelines import device, models, flux_vae, controlnet
from modules.pipelines.common_pipelines import refiner
def get_control_mode(controlnet_config: ControlNetReq):
    """Translate requested ControlNet names into integer control-mode ids.

    A name's id is its position in the fixed mode table below. Names missing
    from the table are skipped. Input order and duplicates are preserved.
    NOTE(review): the table order presumably mirrors the Flux union
    ControlNet's mode numbering; confirm against the model card before
    reordering or extending it.

    Args:
        controlnet_config: request config. Only its ``controlnets`` iterable
            of names is read.

    Returns:
        list[int]: one mode id per recognised name.
    """
    known_modes = ("canny", "tile", "depth", "blur", "pose", "gray", "low_quality")
    return [
        known_modes.index(name)
        for name in controlnet_config.controlnets
        if name in known_modes
    ]
def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    """Build a task-specific Flux pipeline configured for ``request``.

    Looks up the base pipeline registered for ``request.model`` in ``models``.
    It then derives an inpainting, image-to-image or text-to-image pipeline
    from it with ``from_pipe``, depending on the request type. Finally it
    applies the request's VAE, scheduler and LoRA-adapter settings.

    Args:
        request: generation request. Fields read here: ``model``,
            ``controlnet_config``, ``vae``, ``loras`` (dicts with ``repo_id``
            and ``weight`` keys) and ``fast_generation``.

    Returns:
        dict: ``"pipeline"`` maps to the configured pipeline. When a
        ControlNet config is present, the dict also holds ``"control_mode"``
        and ``"controlnet"`` entries.

    Raises:
        ValueError: if no entry in ``models`` has ``repo_id == request.model``.
            Previously this case silently returned None, and callers then
            failed with an opaque TypeError.
    """
    base = next((m for m in models if m['repo_id'] == request.model), None)
    if base is None:
        raise ValueError(f"Unknown model: {request.model!r}")

    pipe_args = {
        "pipeline": base['pipeline'],
    }

    # ControlNet settings are forwarded to from_pipe along with the base pipeline.
    if request.controlnet_config:
        pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
        pipe_args["controlnet"] = [controlnet]

    # Choose the task pipeline, most specific request type first.
    # The request classes presumably subclass one another; TODO confirm.
    if isinstance(request, BaseInpaintReq):
        pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
    elif isinstance(request, BaseImg2ImgReq):
        pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
    elif isinstance(request, BaseReq):
        pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)

    pipeline = pipe_args["pipeline"]

    # Attach the shared Flux VAE, or detach it when request.vae is falsy.
    pipeline.vae = flux_vae if request.vae else None

    # Rebuild a flow-matching Euler scheduler from the current scheduler config.
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipeline.scheduler.config)

    # Collect LoRA adapters. The lists are always initialised, so the
    # fast-generation LoRA works with or without user LoRAs.
    adapter_names = []
    adapter_weights = []
    for i, lora in enumerate(request.loras or []):
        adapter_name = f"lora_{i}"
        pipeline.load_lora_weights(lora['repo_id'], adapter_name=adapter_name)
        adapter_names.append(adapter_name)
        adapter_weights.append(lora['weight'])

    if request.fast_generation:
        hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
        # 0.125 is the adapter weight this module uses for the 8-step Hyper-SD LoRA.
        hyper_weight = 0.125
        pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
        adapter_names.append("hyper_lora")
        adapter_weights.append(hyper_weight)
        # NOTE(review): confirm cleanup() unloads adapters even when
        # request.loras is empty. Otherwise hyper_lora may linger on the
        # shared transformer.

    if adapter_names:
        pipeline.set_adapters(adapter_names, adapter_weights)

    return pipe_args
def get_prompt_attention(pipeline, prompt):
    """Encode ``prompt`` for ``pipeline`` with weighted-token support.

    This is a thin wrapper over ``sd_embed``'s
    ``get_weighted_text_embeddings_flux1``. ``gen_img`` unpacks the result as
    ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    embeddings = get_weighted_text_embeddings_flux1(pipeline, prompt)
    return embeddings
# Gen Function
def _make_generators(seed, count, device):
    """Return ``count`` seeded ``torch.Generator`` objects on ``device``.

    A ``seed`` of ``None``, ``0`` or ``-1`` means "random": each generator gets
    its own seed from ``random.randint(0, 2**32 - 1)``. Any other seed gives
    generator ``i`` the seed ``seed + i``, so a batch is reproducible.
    """
    # The original test `not seed is any([None, 0, -1])` evaluated to
    # `not (seed is True)`. It therefore never randomized, and crashed on None.
    use_random = seed in (None, 0, -1)
    return [
        torch.Generator(device=device).manual_seed(
            random.randint(0, 2**32 - 1) if use_random else seed + i
        )
        for i in range(count)
    ]


def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    """Generate images for ``request`` with a Flux pipeline.

    Steps:
      1. Build the task pipeline via ``get_pipe`` and encode the weighted
         prompt.
      2. Assemble the call arguments: size, steps, guidance and per-image
         generators.
      3. Add optional ControlNet inputs, plus the init/mask images for
         img2img or inpainting requests.
      4. Run the pipeline.
      5. If ``request.refiner`` is set, pass the images through ``refiner``.

    ``cleanup(pipeline, request.loras)`` runs exactly once after the attempt,
    whether it succeeded or failed.

    Returns:
        The ``.images`` of the final pipeline (or refiner) output.

    Raises:
        gr.Error: wraps any exception raised after the pipeline is built. The
            original exception is chained as ``__cause__``. Exceptions from
            ``get_pipe`` itself propagate unwrapped.
    """
    pipe_args = get_pipe(request)
    pipeline = pipe_args["pipeline"]
    try:
        positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)

        # Common Args. Named gen_device so the module-level `device` is not
        # shadowed. NOTE(review): consider using that imported device so the
        # generators match the pipeline's device.
        gen_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            'generator': _make_generators(request.seed, request.num_images_per_prompt, gen_device),
        }

        if request.controlnet_config:
            args['control_mode'] = get_control_mode(request.controlnet_config)
            args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args['strength'] = request.strength

        if isinstance(request, BaseInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        # Generate
        images = pipeline(**args).images

        # Refiner
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        return images
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Single cleanup point. Previously it also ran in `except`, so a
        # failed run cleaned up twice.
        cleanup(pipeline, request.loras)
|