#!/usr/bin/env python from __future__ import annotations import os import random import toml import gradio as gr import numpy as np import PIL.Image import torch import utils import gc from safetensors.torch import load_file import lora_diffusers from lora_diffusers import LoRANetwork, create_network_from_weights from huggingface_hub import hf_hub_download from diffusers.models import AutoencoderKL from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler DESCRIPTION = "Animagine XL" if not torch.cuda.is_available(): DESCRIPTION += "\n
Running on CPU 🥶 This demo does not work on CPU.
" IS_COLAB = utils.is_google_colab() MAX_SEED = np.iinfo(np.int32).max CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1" MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048")) USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1" ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1" MODEL = "Linaqruf/animagine-xl" device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): pipe = DiffusionPipeline.from_pretrained( MODEL, torch_dtype=torch.float16, custom_pipeline="lpw_stable_diffusion_xl.py", use_safetensors=True, variant="fp16", ) pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) if ENABLE_CPU_OFFLOAD: pipe.enable_model_cpu_offload() else: pipe.to(device) if USE_TORCH_COMPILE: pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) else: pipe = None def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: if randomize_seed: seed = random.randint(0, MAX_SEED) return seed def get_image_path(base_path): extensions = [".jpg", ".jpeg", ".png", ".bmp", ".gif"] for ext in extensions: if os.path.exists(base_path + ext): return base_path + ext # If no match is found, return None or raise an error return None def update_selection(selected_state: gr.SelectData): lora_repo = sdxl_loras[selected_state.index]["repo"] lora_weight = sdxl_loras[selected_state.index]["multiplier"] updated_selected_info = f"{lora_repo}" updated_prompt = sdxl_loras[selected_state.index]["sample_prompt"] updated_negative = sdxl_loras[selected_state.index]["sample_negative"] return ( updated_selected_info, selected_state, lora_weight, updated_prompt, negative_presets_dict.get(updated_negative, ""), updated_negative, ) def create_network(text_encoders, unet, state_dict, multiplier, device): network = create_network_from_weights( text_encoders, unet, state_dict, multiplier=multiplier ) network.load_state_dict(state_dict) network.to(device, dtype=unet.dtype) network.apply_to(multiplier=multiplier) return network # def backup_sd(state_dict): # for k, v in state_dict.items(): # state_dict[k] = v.detach().cpu() # return state_dict def generate( prompt: str, negative_prompt: str = "", prompt_2: str = "", negative_prompt_2: str = "", use_prompt_2: bool = False, seed: int = 0, width: int = 1024, height: int = 1024, target_width: int = 1024, target_height: int = 1024, original_width: int = 4096, original_height: int = 4096, guidance_scale: float = 12.0, num_inference_steps: int = 50, use_lora: bool = False, lora_weight: float = 1.0, set_target_size: bool = False, set_original_size: bool = False, selected_state: str = "", ) -> PIL.Image.Image: generator = torch.Generator().manual_seed(seed) network = None # Initialize to None network_state = {"current_lora": None, "multiplier": None} # _unet = pipe.unet.state_dict() # backup_sd(_unet) # _text_encoder = pipe.text_encoder.state_dict() # backup_sd(_text_encoder) # _text_encoder_2 = pipe.text_encoder_2.state_dict() # backup_sd(_text_encoder_2) if not set_original_size: original_width = 4096 original_height = 4096 if not set_target_size: target_width = width target_height = height if negative_prompt == "": negative_prompt = None if not use_prompt_2: prompt_2 = None negative_prompt_2 = None if negative_prompt_2 == "": negative_prompt_2 = None if use_lora: if not selected_state: raise Exception("You must select a LoRA") repo_name = sdxl_loras[selected_state.index]["repo"] full_path_lora = saved_names[selected_state.index] weight_name = sdxl_loras[selected_state.index]["weights"] lora_sd = load_file(full_path_lora) text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if network_state["current_lora"] != repo_name: network = create_network( text_encoders, pipe.unet, lora_sd, lora_weight, device ) network_state["current_lora"] = repo_name network_state["multiplier"] = lora_weight elif network_state["multiplier"] != lora_weight: network = create_network( text_encoders, pipe.unet, lora_sd, lora_weight, device ) network_state["multiplier"] = lora_weight else: if network: network.unapply_to() network = None network_state = {"current_lora": None, "multiplier": None} try: image = pipe( prompt=prompt, negative_prompt=negative_prompt, prompt_2=prompt_2, negative_prompt_2=negative_prompt_2, width=width, height=height, target_size=(target_width, target_height), original_size=(original_width, original_height), guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, generator=generator, output_type="pil", ).images[0] if network: network.unapply_to() network = None return image except Exception as e: print(f"An error occurred: {e}") raise finally: # pipe.unet.load_state_dict(_unet) # pipe.text_encoder.load_state_dict(_text_encoder) # pipe.text_encoder_2.load_state_dict(_text_encoder_2) # del _unet, _text_encoder, _text_encoder_2 if network: network.unapply_to() network = None if use_lora: del lora_sd, text_encoders gc.collect() examples = [ "face focus, cute, masterpiece, best quality, 1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, night, turtleneck", "face focus, bishounen, masterpiece, best quality, 1boy, green hair, sweater, looking at viewer, upper body, beanie, outdoors, night, turtleneck", ] negative_presets_dict = { "None": "", "Standard": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry", "Weighted": "(low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn, bad image", } with open("lora.toml", "r") as file: data = toml.load(file) sdxl_loras = [ { "image": get_image_path(item["image"]), "title": item["title"], "repo": item["repo"], "weights": item["weights"], "multiplier": item["multiplier"] if "multiplier" in item else "1.0", "sample_prompt": item["sample_prompt"], "sample_negative": item["sample_negative"], } for item in data["data"] ] saved_names = [hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras] with gr.Blocks(css="style.css", theme="NoCrypt/miku@1.2.1") as demo: title = gr.HTML( f"""