# aai/old2/modules/events/sdxl_events.py
# Last commit 37112ef by barreloflube: "Refactor flux_helpers.py to enable or disable Vae" (file size 6.32 kB)
import spaces
import gradio as gr
from huggingface_hub import ModelCard
from modules.helpers.common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq
from modules.helpers.sdxl_helpers import gen_img
from config import sdxl_loras
# Module-level alias for the SDXL LoRA catalogue from config. add_to_enabled_loras
# indexes into it by gallery position and reads each entry's "repo" and
# "trigger_word" keys.
loras = sdxl_loras
# Event functions
def update_fast_generation(fast_generation):
    """Return two gr.update objects reflecting the fast-generation toggle.

    When ``fast_generation`` is truthy the first update carries 0.0 and the
    second carries 8; otherwise they carry 7.0 and 20.
    NOTE(review): the values look like guidance scale and inference steps,
    respectively — confirm against the event wiring that consumes this tuple.
    """
    guidance, steps = (0.0, 8) if fast_generation else (7.0, 20)
    return gr.update(value=guidance), gr.update(value=steps)
def add_to_enabled_loras(selected_lora, enabled_loras):
    """Append the selected LoRA to ``enabled_loras`` and reset the selection widgets.

    Args:
        selected_lora: Either an index (an int, or a string that ``int()``
            accepts) into ``loras`` — the position of the LoRA in the gallery —
            or a Hugging Face repo id of the form ``"owner/name"``. Anything
            else (empty string, ``None``, an out-of-range index, a malformed
            repo id) is ignored and nothing is appended.
        enabled_loras: List of ``{"repo_id": ..., "trigger_word": ...}`` dicts.
            It is appended to in place and also returned in the third update.

    Returns:
        A 3-tuple of ``gr.update`` objects: clear ``selected_lora``, clear and
        hide ``custom_lora_info``, and set ``enabled_loras`` to the updated list.

    Raises:
        gr.Error: if the model card of a custom repo id cannot be loaded.
    """
    lora_data = loras

    # Keep the try minimal: only the int() conversion decides which input form we have.
    try:
        lora_index = int(selected_lora)
    except (TypeError, ValueError):
        lora_index = None

    if lora_index is not None:
        # Gallery selection; out-of-range indices are ignored instead of raising IndexError.
        if 0 <= lora_index < len(lora_data):
            lora_info = lora_data[lora_index]
            enabled_loras.append({
                "repo_id": lora_info["repo"],
                "trigger_word": lora_info["trigger_word"],
            })
    elif isinstance(selected_lora, str):
        # Custom LoRA typed by the user as "owner/name"; stray whitespace is dropped.
        repo_id = selected_lora.strip()
        parts = repo_id.split("/")
        if len(parts) == 2 and all(parts):
            try:
                model_card = ModelCard.load(repo_id)
            except Exception as e:
                raise gr.Error(f"Could not load model card for {repo_id}: {e}") from e
            # `or ""` guards against an explicit `instance_prompt: null` in the card metadata.
            trigger_word = model_card.data.get("instance_prompt", "") or ""
            enabled_loras.append({
                "repo_id": repo_id,
                "trigger_word": trigger_word,
            })

    return (
        gr.update(value=""),  # selected_lora
        gr.update(value="", visible=False),  # custom_lora_info
        gr.update(value=enabled_loras),  # enabled_loras
    )
def _is_blank_image(image):
    """Return True if ``image`` is None or every pixel is pure black.

    ``image`` is expected to be a PIL image. It is converted to RGB first so
    the test does not depend on the editor's image mode: a 4-channel RGBA
    pixel such as ``(0, 0, 0, 255)`` never equals ``(0, 0, 0)``. ``getbbox()``
    returns None when there is no non-zero pixel, which avoids building a
    Python list of every pixel.
    """
    if image is None:
        return True
    return image.convert("RGB").getbbox() is None


def _build_controlnet_config(control_units):
    """Return a ControlNetReq holding each (name, image, strength) unit whose image is truthy."""
    config = ControlNetReq(
        controlnets=[],
        control_images=[],
        controlnet_conditioning_scale=[],
    )
    for name, control_image, strength in control_units:
        if control_image:
            config.controlnets.append(name)
            config.control_images.append(control_image)
            config.controlnet_conditioning_scale.append(float(strength))
    return config


@spaces.GPU(duration=75)
def generate_image(
    model, prompt, negative_prompt, fast_generation, enabled_loras, enabled_embeddings,  # type: ignore
    lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,  # type: ignore
    img2img_image, inpaint_image, canny_image, pose_image, depth_image, scribble_image,  # type: ignore
    img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, scribble_strength,  # type: ignore
    resize_mode,
    scheduler, image_height, image_width, image_num_images_per_prompt,  # type: ignore
    image_num_inference_steps, image_clip_skip, image_guidance_scale, image_seed,  # type: ignore
    refiner, vae
):
    """Build a generation request from the UI inputs and run ``gen_img`` on it.

    The request type is chosen by the first input present, in this order:
    img2img image -> ``BaseImg2ImgReq``; inpaint editor value with a non-blank
    background and a drawn layer -> ``BaseInpaintReq``; any of the canny /
    pose / depth / scribble images -> ``BaseReq`` with a ``ControlNetReq``;
    otherwise a plain ``BaseReq``.

    Args:
        enabled_loras: List of dicts with a ``"repo_id"`` key; each is paired
            positionally with one of ``lora_slider_0..5`` as its weight.
        enabled_embeddings: Passed through as ``embeddings`` when non-empty.
        inpaint_image: Dict-like editor value read via its ``"background"``
            and ``"layers"`` keys; the first layer is used as the mask.
        *_strength: Converted with ``float()`` before use.
        All other arguments are copied onto the ``BaseReq`` fields of the
        same meaning.

    Returns:
        ``gr.update`` whose value is the result of ``gen_img`` and which sets
        ``interactive=True``.

    Raises:
        gr.Error: any failure, including ``gr.Error`` raised downstream
            (re-raised unchanged rather than double-wrapped).
    """
    try:
        base_args = BaseReq(
            model=model,
            prompt=prompt,
            negative_prompt=negative_prompt,
            fast_generation=fast_generation,
            loras=None,
            embeddings=None,
            resize_mode=resize_mode,
            scheduler=scheduler,
            height=image_height,
            width=image_width,
            num_images_per_prompt=image_num_images_per_prompt,
            num_inference_steps=image_num_inference_steps,
            clip_skip=image_clip_skip,
            guidance_scale=image_guidance_scale,
            seed=image_seed,
            refiner=refiner,
            vae=vae,
            controlnet_config=None,
        )

        if enabled_loras:
            # zip() stops at the shorter sequence: only as many LoRAs as there
            # are weight sliders (six) are forwarded.
            lora_weights = [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]
            base_args.loras = []
            for enabled_lora, weight in zip(enabled_loras, lora_weights):
                if enabled_lora["repo_id"]:
                    base_args.loras.append({
                        "repo_id": enabled_lora["repo_id"],
                        "weight": weight,
                    })

        if enabled_embeddings:
            base_args.embeddings = enabled_embeddings

        control_units = [
            ("canny", canny_image, canny_strength),
            ("pose", pose_image, pose_strength),
            ("depth", depth_image, depth_strength),
            ("scribble", scribble_image, scribble_strength),
        ]

        if img2img_image:
            base_args = BaseImg2ImgReq(
                **base_args.__dict__,
                image=img2img_image,
                strength=float(img2img_strength),
            )
        elif inpaint_image:
            # NOTE(review): because this is an elif chain, a truthy inpaint_image skips
            # ControlNet even when its background is blank — confirm whether the editor
            # sends None when it is empty.
            background = inpaint_image.get("background")
            layers = inpaint_image.get("layers") or []
            image = None if _is_blank_image(background) else background
            mask_image = layers[0] if image is not None and layers else None
            if image is not None and mask_image is not None:
                base_args = BaseInpaintReq(
                    **base_args.__dict__,
                    image=image,
                    mask_image=mask_image,
                    strength=float(inpaint_strength),
                )
        elif any(unit_image for _, unit_image, _ in control_units):
            # Scribble is part of the trigger too; previously a scribble-only request was ignored.
            base_args.controlnet_config = _build_controlnet_config(control_units)

        return gr.update(
            value=gen_img(base_args),
            interactive=True,
        )
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e