import spaces
import os
from stablepy import Model_Diffusers
from stablepy.diffusers_vanilla.model import scheduler_names
from stablepy.diffusers_vanilla.style_prompt_config import STYLE_NAMES
import torch
import re
import shutil
import random
from stablepy import (
    CONTROLNET_MODEL_IDS,
    VALID_TASKS,
    T2I_PREPROCESSOR_NAME,
    FLASH_LORA,
    SCHEDULER_CONFIG_MAP,
    scheduler_names,
    IP_ADAPTER_MODELS,
    IP_ADAPTERS_SD,
    IP_ADAPTERS_SDXL,
    REPO_IMAGE_ENCODER,
    ALL_PROMPT_WEIGHT_OPTIONS,
    SD15_TASKS,
    SDXL_TASKS,
)
import urllib.parse
import gradio as gr
from PIL import Image
import IPython.display
import time, json
from IPython.utils import capture
import logging
logging.getLogger("diffusers").setLevel(logging.ERROR)
import diffusers
diffusers.utils.logging.set_verbosity(40)
import warnings
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")
from stablepy import logger
logger.setLevel(logging.CRITICAL)

from env import (
    hf_token,
    hf_read_token, # to use only for private repos
    CIVITAI_API_KEY,
    HF_LORA_PRIVATE_REPOS1,
    HF_LORA_PRIVATE_REPOS2,
    HF_LORA_ESSENTIAL_PRIVATE_REPO,
    HF_VAE_PRIVATE_REPO,
    HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO,
    HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO,
    directory_models,
    directory_loras,
    directory_vaes,
    directory_embeds,
    directory_embeds_sdxl,
    directory_embeds_positive_sdxl,
    load_diffusers_format_model,
    download_model_list,
    download_lora_list,
    download_vae_list,
    download_embeds,
)

# ControlNet task name -> preprocessor names available for that task, in the
# order they are listed (presumably the order shown in the preprocessor dropdown).
preprocessor_controlnet = {
    "openpose": ["Openpose", "None"],
    "scribble": ["HED", "Pidinet", "None"],
    "softedge": ["Pidinet", "HED", "HED safe", "Pidinet safe", "None"],
    "segmentation": ["UPerNet", "None"],
    "depth": ["DPT", "Midas", "None"],
    "normalbae": ["NormalBae", "None"],
    "lineart": ["Lineart", "Lineart coarse", "Lineart (anime)", "None", "None (anime)"],
    "shuffle": ["ContentShuffle", "None"],
    "canny": ["Canny"],
    "mlsd": ["MLSD"],
    "ip2p": ["ip2p"],
}

# GUI task label -> stablepy task name.
# Disabled on purpose: the SDXL T2I-Adapter tasks (canny, sketch, lineart,
# depth-midas, openpose -> 'sdxl_*_t2i') have no step-callback parameters and do
# not work with diffusers 0.29.0; 'lineart_anime ControlNet' is disabled as well.
task_stablepy = {
    'txt2img': 'txt2img',
    'img2img': 'img2img',
    'inpaint': 'inpaint',
    'openpose ControlNet': 'openpose',
    'canny ControlNet': 'canny',
    'mlsd ControlNet': 'mlsd',
    'scribble ControlNet': 'scribble',
    'softedge ControlNet': 'softedge',
    'segmentation ControlNet': 'segmentation',
    'depth ControlNet': 'depth',
    'normalbae ControlNet': 'normalbae',
    'lineart ControlNet': 'lineart',
    'shuffle ControlNet': 'shuffle',
    'ip2p ControlNet': 'ip2p',
    'optical pattern ControlNet': 'pattern',
    'tile realistic': 'sdxl_tile_realistic',
}

# GUI labels in insertion order; element 0 ('txt2img') is used as the default task.
task_model_list = [*task_stablepy]


def _build_download_command(directory, url, hf_token="", civitai_api_key=""):
    """Return ``(argv, cwd)`` that downloads *url* into *directory*, or ``None``.

    The tool is chosen from the host in the URL:

    * Google Drive: ``gdown --fuzzy <url>`` run with ``cwd=directory``.
    * Hugging Face: ``aria2c``. ``?download=true`` is removed, ``/blob/`` is
      rewritten to ``/resolve/``, and the output file name is the last URL
      segment. A Bearer header is sent when *hf_token* is set.
    * Civitai: the query string is dropped and ``?token=<civitai_api_key>`` is
      appended. Without a key, a warning is printed and ``None`` is returned.
    * Anything else: plain ``aria2c`` into *directory*.

    A blank URL also yields ``None``. ``cwd`` is ``None`` unless the tool must
    run inside *directory*. The command is an argument list, so no shell ever
    parses the URL or the tokens.
    """
    url = url.strip()
    if not url:
        return None
    directory = str(directory)
    aria2c_log = ["--console-log-level=error", "--summary-interval=10"]
    aria2c_segments = ["-c", "-x", "16", "-k", "1M", "-s", "16"]

    if "drive.google.com" in url:
        return ["gdown", "--fuzzy", url], directory

    if "huggingface.co" in url:
        url = url.replace("?download=true", "").replace("/blob/", "/resolve/")
        file_name = url.split('/')[-1]
        if hf_token:
            head = ["aria2c", *aria2c_log, f"--header=Authorization: Bearer {hf_token}"]
        else:
            head = ["aria2c", "--optimize-concurrent-downloads", *aria2c_log]
        return [*head, *aria2c_segments, url, "-d", directory, "-o", file_name], None

    if "civitai.com" in url:
        if not civitai_api_key:
            print("\033[91mYou need an API key to download Civitai models.\033[0m")
            return None
        url = url.split("?")[0] + f"?token={civitai_api_key}"
        return ["aria2c", *aria2c_log, *aria2c_segments, "-d", directory, url], None

    return ["aria2c", *aria2c_log, *aria2c_segments, "-d", directory, url], None


def download_things(directory, url, hf_token="", civitai_api_key=""):
    """Download *url* into *directory* with gdown or aria2c (best effort).

    See ``_build_download_command`` for how the tool and its flags are chosen.
    A non-zero exit status from the tool is not treated as an error. A missing
    executable is printed instead of raised. Always returns ``None``.
    """
    import subprocess

    plan = _build_download_command(directory, url, hf_token, civitai_api_key)
    if plan is None:
        return
    argv, cwd = plan
    try:
        # cwd= (rather than os.chdir) leaves the process-wide working directory
        # untouched, even if the tool fails to start.
        subprocess.run(argv, cwd=cwd, check=False)
    except OSError as e:
        print(f"\033[91mDownload failed ({argv[0]}): {e}\033[0m")


def get_model_list(directory_path):
    """Return paths of model-weight files found directly in *directory_path*.

    A file qualifies when its extension, compared case-insensitively, is one of
    ``.ckpt``, ``.pt``, ``.pth``, ``.safetensors`` or ``.bin``. Each match is
    echoed to stdout. If *directory_path* is not an existing directory, an empty
    list is returned instead of raising. The result keeps ``os.listdir`` order,
    which is arbitrary.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    if not os.path.isdir(directory_path):
        return []

    model_list = []
    for filename in os.listdir(directory_path):
        if os.path.splitext(filename)[1].lower() in valid_extensions:
            file_path = os.path.join(directory_path, filename)
            model_list.append(file_path)
            print('\033[34mFILE: ' + file_path + '\033[0m')
    return model_list


def process_string(input_string):
    """Split an ``owner/name`` id into ``(name, input_string)``.

    Returns ``None`` unless *input_string* contains exactly one ``/``.
    """
    segments = input_string.split('/')
    if len(segments) != 2:
        return None
    return (segments[1], input_string)

## BEGIN MOD
from modutils import (
    to_list,
    list_uniq,
    list_sub,
    get_model_id_list,
    get_tupled_embed_list,
    get_tupled_model_list,
    get_lora_model_list,
    download_private_repo,
)

# Comma-joined URL strings built from the env lists; the download loops split them
# on ',' again.
# - **Download Models**
download_model = ", ".join(download_model_list)
# - **Download VAEs**
download_vae = ", ".join(download_vae_list)
# - **Download LoRAs**
download_lora = ", ".join(download_lora_list)

# Fetch the private VAE repo into directory_vaes at import time (network I/O).
# The essential-LoRA repo download is currently disabled.
#download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)

# Merge configured diffusers model ids with those from get_model_id_list(), with
# duplicates removed by list_uniq (presumably order-preserving; verify in modutils).
load_diffusers_format_model = list_uniq(load_diffusers_format_model + get_model_id_list())
## END MOD

# Environment variables take precedence. When one is unset, keep the value
# imported from env instead of overwriting it with None.
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY", CIVITAI_API_KEY)
hf_token = os.environ.get("HF_TOKEN", hf_token)


def _download_missing(directory, urls):
    """Download each URL in *urls* unless ``directory/<basename>`` already exists.

    Entries are stripped and blank ones are skipped. The basename checked is the
    last ``/``-separated segment of the URL. The existence check uses the
    configured *directory*, so it matches where download_things saves files.
    """
    for url in urls:
        url = url.strip()
        if not url:
            continue
        if not os.path.exists(os.path.join(directory, url.split('/')[-1])):
            download_things(directory, url, hf_token, CIVITAI_API_KEY)


# Download stuffs
_download_missing(directory_models, download_model.split(','))
_download_missing(directory_vaes, download_vae.split(','))
_download_missing(directory_loras, download_lora.split(','))

# Download Embeddings
_download_missing(directory_embeds, download_embeds)

# Build the model / LoRA / VAE / embedding choice lists from what is on disk now.
embed_list = get_model_list(directory_embeds)
model_list = load_diffusers_format_model + get_model_list(directory_models)
## BEGIN MOD
lora_model_list = get_lora_model_list()
# "None" comes first: the generation code maps it to vae_model=None (model default VAE).
vae_model_list = ["None", *get_model_list(directory_vaes)]

#download_private_repo(HF_SDXL_EMBEDS_NEGATIVE_PRIVATE_REPO, directory_embeds_sdxl, False)
#download_private_repo(HF_SDXL_EMBEDS_POSITIVE_PRIVATE_REPO, directory_embeds_positive_sdxl, False)
embed_sdxl_list = [*get_model_list(directory_embeds_sdxl), *get_model_list(directory_embeds_positive_sdxl)]

def get_embed_list(pipeline_name):
    """Return tupled textual-inversion embeddings matching *pipeline_name*.

    SDXL pipelines ("StableDiffusionXLPipeline") get the SDXL embeddings;
    every other pipeline gets the SD 1.5 list.
    """
    if pipeline_name == "StableDiffusionXLPipeline":
        candidates = embed_sdxl_list
    else:
        candidates = embed_list
    return get_tupled_embed_list(candidates)


## END MOD

# Startup marker: all downloads and list building above have finished.
print('\033[33m🏁 Download and listing of valid models completed.\033[0m')

# Upscaler label -> source. None, "Lanczos" and "Nearest" are passed through
# as-is; every other value is a URL that is downloaded into ./upscalers on first use.
upscaler_dict_gui = {
    None: None,
    "Lanczos": "Lanczos",
    "Nearest": "Nearest",
    "RealESRGAN_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    "RealESRGAN_x4plus_anime_6B": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
    "RealESRGAN_x2plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
    "realesr-animevideov3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
    "realesr-general-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    "realesr-general-wdn-x4v3": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
    "4x-UltraSharp": "https://huggingface.co/Shandypur/ESRGAN-4x-UltraSharp/resolve/main/4x-UltraSharp.pth",
    "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
    "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
    "AnimeSharp4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/AnimeSharp%204x.pth",
    "lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
    "RealisticRescaler4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/RealisticRescaler%204x.pth",
    "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth",
}


def extract_parameters(input_string):
    """Parse Automatic1111-style generation "infotext" into a dict.

    Expected layout (line breaks are treated as spaces)::

        <prompt> Negative prompt: <neg> Steps: <n>, Key: value, ...

    Returns a dict with:
        * ``prompt``: text before "Negative prompt:" (the whole input if that
          marker is absent), right-stripped.
        * ``neg_prompt``: text between "Negative prompt:" and "Steps:",
          right-stripped.
        * ``Steps`` (int), ``Size`` ("WxH" string), ``width``/``height`` (int),
          when present.
        * one string entry per ``Key: value`` pair from "Steps:" on. The key is
          the last word before the colon (e.g. "CFG scale" -> ``scale``), and
          surrounding double quotes are stripped from the value. A pair never
          overwrites a key that is already set, so the typed ``Steps`` and the
          main ``prompt`` survive (e.g. "ADetailer prompt:").
    """
    parameters = {}
    # Treat line breaks as word separators so multi-line prompts don't fuse words.
    input_string = input_string.replace("\r\n", "\n").replace("\n", " ").rstrip()

    if "Negative prompt:" not in input_string:
        print("Negative prompt not detected")
        parameters["prompt"] = input_string
        return parameters

    # Split on the first marker only, so later text is not silently dropped.
    prompt_part, rest = input_string.split("Negative prompt:", 1)
    parameters["prompt"] = prompt_part.rstrip()
    if "Steps:" not in rest:
        print("Steps not detected")
        parameters["neg_prompt"] = rest.rstrip()
        return parameters
    neg_part, settings = rest.split("Steps:", 1)
    parameters["neg_prompt"] = neg_part.rstrip()
    input_string = "Steps:" + settings

    # Extracting Steps
    steps_match = re.search(r'Steps: (\d+)', input_string)
    if steps_match:
        parameters['Steps'] = int(steps_match.group(1))

    # Extracting Size
    size_match = re.search(r'Size: (\d+x\d+)', input_string)
    if size_match:
        parameters['Size'] = size_match.group(1)
        width, height = map(int, parameters['Size'].split('x'))
        parameters['width'] = width
        parameters['height'] = height

    # Extracting other parameters; setdefault keeps the values set above.
    for key, value in re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string):
        parameters.setdefault(key, value.strip('"'))

    return parameters


## BEGIN MOD
class GuiSD:
    """Own the single stablepy ``Model_Diffusers`` instance used by the GUI.

    ``__init__`` builds it eagerly in fp16 with the
    ``cagliostrolab/animagine-xl-3.1`` base model. Later calls re-target it with
    ``load_pipe``.
    """

    def __init__(self):
        print("Loading model...")
        self.model = Model_Diffusers(
            base_model_id="cagliostrolab/animagine-xl-3.1",
            task_name="txt2img",
            vae_model=None,
            type_model_precision=torch.float16,
            retain_task_model_in_cache=False,
        )

    @staticmethod
    def _check_vae_compatibility(model_name, vae_model):
        """Drop a VAE whose family (SDXL vs SD 1.5) does not match the model.

        Only models listed in ``model_list`` are checked. Both families are
        guessed from names: "xl" anywhere in the model name means SDXL, and
        "sdxl" anywhere in the VAE name means an SDXL VAE.

        Returns:
            ``(vae_model, warning)``. ``vae_model`` is unchanged when compatible
            (or falsy) and ``None`` (use the model's default VAE) on a mismatch.
            ``warning`` is the user-facing explanation, or ``""`` if nothing changed.
        """
        if not vae_model or model_name not in model_list:
            return vae_model, ""
        model_is_xl = "xl" in model_name.lower()
        vae_is_xl = "sdxl" in vae_model.lower()
        if model_is_xl == vae_is_xl:
            return vae_model, ""
        model_type = "SDXL" if model_is_xl else "SD 1.5"
        warning = (
            f"The selected VAE is for a { 'SD 1.5' if model_is_xl else 'SDXL' } model, but you"
            f" are using a { model_type } model. The default VAE "
            "will be used."
        )
        return None, warning

    def infer_short(self, model, pipe_params, progress=gr.Progress(track_tqdm=True)):
        """Call *model* with *pipe_params* and return ``[(image, None), ...]``.

        A single non-list image result is wrapped into a one-element list. The
        second value returned by *model* is ignored.
        """
        progress(0, desc="Start inference...")
        images, _image_list = model(**pipe_params)
        progress(1, desc="Inference completed.")
        if not isinstance(images, list):
            images = [images]
        return [(image, None) for image in images]

    def load_new_model(self, model_name, vae_model, task, progress=gr.Progress(track_tqdm=True)):
        """Load *model_name* for GUI task *task* into the shared pipeline.

        This is a generator, so nothing is loaded until it is iterated. A VAE
        that does not match the model family is silently replaced by the model
        default (see ``_check_vae_compatibility``).

        Yields:
            "Loading model: <name>" before loading and "Model loaded: <name>" after.
        """
        yield f"Loading model: {model_name}"

        vae_model = vae_model if vae_model != "None" else None
        vae_model, _ = self._check_vae_compatibility(model_name, vae_model)

        self.model.load_pipe(
            model_name,
            task_name=task_stablepy[task],
            vae_model=vae_model,
            type_model_precision=torch.float16,
            retain_task_model_in_cache=False,
        )
        yield f"Model loaded: {model_name}"

    @spaces.GPU
    def generate_pipeline(
        self,
        prompt,
        neg_prompt,
        num_images,
        steps,
        cfg,
        clip_skip,
        seed,
        lora1,
        lora_scale1,
        lora2,
        lora_scale2,
        lora3,
        lora_scale3,
        lora4,
        lora_scale4,
        lora5,
        lora_scale5,
        sampler,
        img_height,
        img_width,
        model_name,
        vae_model,
        task,
        image_control,
        preprocessor_name,
        preprocess_resolution,
        image_resolution,
        style_prompt,  # list []
        style_json_file,
        image_mask,
        strength,
        low_threshold,
        high_threshold,
        value_threshold,
        distance_threshold,
        controlnet_output_scaling_in_unet,
        controlnet_start_threshold,
        controlnet_stop_threshold,
        textual_inversion,
        syntax_weights,
        upscaler_model_path,
        upscaler_increases_size,
        esrgan_tile,
        esrgan_tile_overlap,
        hires_steps,
        hires_denoising_strength,
        hires_sampler,
        hires_prompt,
        hires_negative_prompt,
        hires_before_adetailer,
        hires_after_adetailer,
        loop_generation,
        leave_progress_bar,
        disable_progress_bar,
        image_previews,
        display_images,
        save_generated_images,
        image_storage_location,
        retain_compel_previous_load,
        retain_detailfix_model_previous_load,
        retain_hires_model_previous_load,
        t2i_adapter_preprocessor,
        t2i_adapter_conditioning_scale,
        t2i_adapter_conditioning_factor,
        xformers_memory_efficient_attention,
        freeu,
        generator_in_cpu,
        adetailer_inpaint_only,
        adetailer_verbose,
        adetailer_sampler,
        adetailer_active_a,
        prompt_ad_a,
        negative_prompt_ad_a,
        strength_ad_a,
        face_detector_ad_a,
        person_detector_ad_a,
        hand_detector_ad_a,
        mask_dilation_a,
        mask_blur_a,
        mask_padding_a,
        adetailer_active_b,
        prompt_ad_b,
        negative_prompt_ad_b,
        strength_ad_b,
        face_detector_ad_b,
        person_detector_ad_b,
        hand_detector_ad_b,
        mask_dilation_b,
        mask_blur_b,
        mask_padding_b,
        retain_task_cache_gui,
        image_ip1,
        mask_ip1,
        model_ip1,
        mode_ip1,
        scale_ip1,
        image_ip2,
        mask_ip2,
        model_ip2,
        mode_ip2,
        scale_ip2,
        progress=gr.Progress(track_tqdm=True),
    ):
        """Run one GUI generation request and return ``(images, info)``.

        Most parameters are forwarded unchanged to the stablepy model call via
        ``pipe_params``. Exceptions:

        * LoRA slots set to "None" are sent as ``None``.
        * ``style_json_file`` is currently ignored and sent as ``""``.
        * ``task`` is a GUI label and is mapped through ``task_stablepy``.

        Before inference this method:

        * merges the model's recommended prompts (``insert_model_recom_prompt``)
          and refreshes the global ``lora_model_list``;
        * drops a VAE of the wrong family and warns about LoRAs that look like
          the wrong family (both are name heuristics);
        * downloads the selected upscaler into ./upscalers on first use;
        * (re)loads the pipeline for ``model_name``/``task``.

        Raises:
            ValueError: a non-txt2img task without ``image_control``, or an
                inpaint task without ``image_mask``.

        Returns:
            ``([(image, None), ...], info)``. ``info`` is "COMPLETED. Seeds: ..."
            followed by any VAE/LoRA notes, joined with ``<br>``.
        """
        progress(0, desc="Preparing inference...")

        vae_model = vae_model if vae_model != "None" else None
        loras_list = [lora1, lora2, lora3, lora4, lora5]
        vae_msg = f"VAE: {vae_model}" if vae_model else ""
        msg_lora = []

## BEGIN MOD
        prompt, neg_prompt = insert_model_recom_prompt(prompt, neg_prompt, model_name)
        global lora_model_list
        lora_model_list = get_lora_model_list()
## END MOD

        vae_model, msg_inc_vae = self._check_vae_compatibility(model_name, vae_model)
        if msg_inc_vae:
            gr.Info(msg_inc_vae)
            vae_msg = msg_inc_vae

        if model_name in model_list:
            model_is_xl = "xl" in model_name.lower()
            model_type = "SDXL" if model_is_xl else "SD 1.5"
            for la in loras_list:
                if la is not None and la != "None" and la in lora_model_list:
                    print(la)
                    # Name heuristic: only these LoRA names are treated as SD 1.5.
                    lora_type = ("animetarot" in la.lower() or "Hyper-SD15-8steps".lower() in la.lower())
                    if (model_is_xl and lora_type) or (not model_is_xl and not lora_type):
                        msg_inc_lora = f"The LoRA {la} is for { 'SD 1.5' if model_is_xl else 'SDXL' }, but you are using { model_type }."
                        gr.Info(msg_inc_lora)
                        msg_lora.append(msg_inc_lora)

        task = task_stablepy[task]

        # Collect IP-Adapter inputs; an adapter is used only when it has an image.
        params_ip_img = []
        params_ip_msk = []
        params_ip_model = []
        params_ip_mode = []
        params_ip_scale = []

        all_adapters = [
            (image_ip1, mask_ip1, model_ip1, mode_ip1, scale_ip1),
            (image_ip2, mask_ip2, model_ip2, mode_ip2, scale_ip2),
        ]

        for imgip, mskip, modelip, modeip, scaleip in all_adapters:
            if imgip:
                params_ip_img.append(imgip)
                if mskip:
                    params_ip_msk.append(mskip)
                params_ip_model.append(modelip)
                params_ip_mode.append(modeip)
                params_ip_scale.append(scaleip)

        # First load
        model_precision = torch.float16
        if not self.model:
            print("Loading model...")
            self.model = Model_Diffusers(
                base_model_id=model_name,
                task_name=task,
                vae_model=vae_model,
                type_model_precision=model_precision,
                retain_task_model_in_cache=retain_task_cache_gui,
            )

        if task != "txt2img" and not image_control:
            raise ValueError("No control image found: To use this function, you have to upload an image in 'Image ControlNet/Inpaint/Img2img'")

        if task == "inpaint" and not image_mask:
            raise ValueError("No mask image found: Specify one in 'Image Mask'")

        if upscaler_model_path in [None, "Lanczos", "Nearest"]:
            upscaler_model = upscaler_model_path
        else:
            directory_upscalers = 'upscalers'
            os.makedirs(directory_upscalers, exist_ok=True)

            url_upscaler = upscaler_dict_gui[upscaler_model_path]
            upscaler_model = f"./upscalers/{url_upscaler.split('/')[-1]}"

            if not os.path.exists(upscaler_model):
                download_things(directory_upscalers, url_upscaler, hf_token)

        logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)

        print("Config model:", model_name, vae_model, loras_list)

        self.model.load_pipe(
            model_name,
            task_name=task,
            vae_model=vae_model,
            type_model_precision=model_precision,
            retain_task_model_in_cache=retain_task_cache_gui,
        )

        adetailer_params_A = {
            "face_detector_ad" : face_detector_ad_a,
            "person_detector_ad" : person_detector_ad_a,
            "hand_detector_ad" : hand_detector_ad_a,
            "prompt": prompt_ad_a,
            "negative_prompt" : negative_prompt_ad_a,
            "strength" : strength_ad_a,
            # "image_list_task" : None,
            "mask_dilation" : mask_dilation_a,
            "mask_blur" : mask_blur_a,
            "mask_padding" : mask_padding_a,
            "inpaint_only" : adetailer_inpaint_only,
            "sampler" : adetailer_sampler,
        }

        # NOTE(review): unlike A, B does not receive inpaint_only/sampler — confirm
        # whether stablepy's defaults for B are intended here.
        adetailer_params_B = {
            "face_detector_ad" : face_detector_ad_b,
            "person_detector_ad" : person_detector_ad_b,
            "hand_detector_ad" : hand_detector_ad_b,
            "prompt": prompt_ad_b,
            "negative_prompt" : negative_prompt_ad_b,
            "strength" : strength_ad_b,
            # "image_list_task" : None,
            "mask_dilation" : mask_dilation_b,
            "mask_blur" : mask_blur_b,
            "mask_padding" : mask_padding_b,
        }
        pipe_params = {
            "prompt": prompt,
            "negative_prompt": neg_prompt,
            "img_height": img_height,
            "img_width": img_width,
            "num_images": num_images,
            "num_steps": steps,
            "guidance_scale": cfg,
            "clip_skip": clip_skip,
            "seed": seed,
            "image": image_control,
            "preprocessor_name": preprocessor_name,
            "preprocess_resolution": preprocess_resolution,
            "image_resolution": image_resolution,
            "style_prompt": style_prompt if style_prompt else "",
            "style_json_file": "",
            "image_mask": image_mask,  # only for Inpaint
            "strength": strength,  # only for Inpaint or ...
            "low_threshold": low_threshold,
            "high_threshold": high_threshold,
            "value_threshold": value_threshold,
            "distance_threshold": distance_threshold,
            "lora_A": lora1 if lora1 != "None" else None,
            "lora_scale_A": lora_scale1,
            "lora_B": lora2 if lora2 != "None" else None,
            "lora_scale_B": lora_scale2,
            "lora_C": lora3 if lora3 != "None" else None,
            "lora_scale_C": lora_scale3,
            "lora_D": lora4 if lora4 != "None" else None,
            "lora_scale_D": lora_scale4,
            "lora_E": lora5 if lora5 != "None" else None,
            "lora_scale_E": lora_scale5,
## BEGIN MOD
            "textual_inversion": get_embed_list(self.model.class_name) if textual_inversion else [],
## END MOD
            "syntax_weights": syntax_weights,  # "Classic"
            "sampler": sampler,
            "xformers_memory_efficient_attention": xformers_memory_efficient_attention,
            "gui_active": True,
            "loop_generation": loop_generation,
            "controlnet_conditioning_scale": float(controlnet_output_scaling_in_unet),
            "control_guidance_start": float(controlnet_start_threshold),
            "control_guidance_end": float(controlnet_stop_threshold),
            "generator_in_cpu": generator_in_cpu,
            "FreeU": freeu,
            "adetailer_A": adetailer_active_a,
            "adetailer_A_params": adetailer_params_A,
            "adetailer_B": adetailer_active_b,
            "adetailer_B_params": adetailer_params_B,
            "leave_progress_bar": leave_progress_bar,
            "disable_progress_bar": disable_progress_bar,
            "image_previews": image_previews,
            "display_images": display_images,
            "save_generated_images": save_generated_images,
            "image_storage_location": image_storage_location,
            "retain_compel_previous_load": retain_compel_previous_load,
            "retain_detailfix_model_previous_load": retain_detailfix_model_previous_load,
            "retain_hires_model_previous_load": retain_hires_model_previous_load,
            "t2i_adapter_preprocessor": t2i_adapter_preprocessor,
            "t2i_adapter_conditioning_scale": float(t2i_adapter_conditioning_scale),
            "t2i_adapter_conditioning_factor": float(t2i_adapter_conditioning_factor),
            "upscaler_model_path": upscaler_model,
            "upscaler_increases_size": upscaler_increases_size,
            "esrgan_tile": esrgan_tile,
            "esrgan_tile_overlap": esrgan_tile_overlap,
            "hires_steps": hires_steps,
            "hires_denoising_strength": hires_denoising_strength,
            "hires_prompt": hires_prompt,
            "hires_negative_prompt": hires_negative_prompt,
            "hires_sampler": hires_sampler,
            "hires_before_adetailer": hires_before_adetailer,
            "hires_after_adetailer": hires_after_adetailer,
            "ip_adapter_image": params_ip_img,
            "ip_adapter_mask": params_ip_msk,
            "ip_adapter_model": params_ip_model,
            "ip_adapter_mode": params_ip_mode,
            "ip_adapter_scale": params_ip_scale,
        }

        # Maybe fix lora issue: 'Cannot copy out of meta tensor; no data!''
        self.model.pipe.to("cuda:0" if torch.cuda.is_available() else "cpu")

        progress(1, desc="Inference preparation completed. Starting inference...")

        info_state = f"COMPLETED. Seeds: {seed}"
        if vae_msg:
            info_state = info_state + "<br>" + vae_msg
        if msg_lora:
            info_state = info_state + "<br>" + "<br>".join(msg_lora)
        return self.infer_short(self.model, pipe_params), info_state
## END MOD


from pathlib import Path
from modutils import (
    safe_float,
    escape_lora_basename,
    to_lora_key,
    to_lora_path,
    get_local_model_list,
    get_private_lora_model_lists,
    get_valid_lora_name,
    get_valid_lora_path,
    get_valid_lora_wt,
    get_lora_info,
    normalize_prompt_list,
    get_civitai_info,
    search_lora_on_civitai,
)

# Shared GUI state. Constructing GuiSD builds its Model_Diffusers instance
# (base model cagliostrolab/animagine-xl-3.1, fp16) at import time.
sd_gen = GuiSD()
@spaces.GPU
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
           model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
           lora3 = None, lora3_wt = 1.0, lora4 = None, lora4_wt = 1.0, lora5 = None, lora5_wt = 1.0,
           sampler = "Euler a", vae = None, progress=gr.Progress(track_tqdm=True)):
    """Generate one txt2img image with the shared ``sd_gen`` pipeline.

    Steps:

    * If *randomize_seed* is set, pick a seed in ``[0, int32 max]``; otherwise
      use *seed* as given.
    * Merge the model's recommended prompts into the prompts.
    * Resolve LoRA slots via ``set_prompt_loras`` and ``get_valid_lora_path``.
    * Run ``GuiSD.generate_pipeline`` with fixed defaults for every other GUI
      control.

    Returns:
        The first generated image, or ``None`` if no image came back.
    """
    import numpy as np
    MAX_SEED = np.iinfo(np.int32).max

    progress(0, desc="Preparing...")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Pass the seed straight through. Do not use torch.Generator.seed() here: it
    # re-seeds from a non-deterministic source and returns that value instead.
    seed = int(seed)

    prompt, negative_prompt = insert_model_recom_prompt(prompt, negative_prompt, model_name)
    progress(0.5, desc="Preparing...")
    lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt = \
        set_prompt_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt)
    lora1 = get_valid_lora_path(lora1)
    lora2 = get_valid_lora_path(lora2)
    lora3 = get_valid_lora_path(lora3)
    lora4 = get_valid_lora_path(lora4)
    lora5 = get_valid_lora_path(lora5)
    progress(1, desc="Preparation completed. Starting inference preparation...")

    # generate_pipeline() (re)loads model_name/vae itself through load_pipe, so no
    # separate load step is needed. load_new_model is a generator and would do
    # nothing unless iterated.
    images, info = sd_gen.generate_pipeline(prompt, negative_prompt, 1, num_inference_steps,
        guidance_scale, True, seed, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt,
        lora4, lora4_wt, lora5, lora5_wt, sampler,
        height, width, model_name, vae, task_model_list[0], None, "Canny", 512, 1024,
        None, None, None, 0.35, 100, 200, 0.1, 0.1, 1.0, 0., 1., False, "Classic", None,
        1.0, 100, 10, 30, 0.55, "Use same sampler", "", "",
        False, True, 1, True, False, False, False, False, "./images", False, False, False, True, 1, 0.55,
        False, False, False, True, False, "Use same sampler", False, "", "", 0.35, True, True, False, 4, 4, 32,
        False, "", "", 0.35, True, True, False, 4, 4, 32,
        True, None, None, "plus_face", "original", 0.7, None, None, "base", "style", 0.7
    )

    progress(1, desc="Inference completed.")
    output_image = images[0][0] if images else None

    return output_image


@spaces.GPU
def _infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
            model_name = load_diffusers_format_model[0], lora1 = None, lora1_wt = 1.0, lora2 = None, lora2_wt = 1.0,
            lora3 = None, lora3_wt = 1.0, lora4 = None, lora4_wt = 1.0, lora5 = None, lora5_wt = 1.0,
            sampler = "Euler a", vae = None, progress=gr.Progress(track_tqdm=True)):
    """Placeholder with ``infer``'s signature: ignore every input and make the target visible."""
    show = gr.update(visible=True)
    return show


# Tag both GPU entry points with a `zerogpu` attribute.
# NOTE(review): presumably read by the Spaces/ZeroGPU tooling — confirm what consumes it.
infer.zerogpu = True
_infer.zerogpu = True


def pass_result(result):
    """Identity function: hand *result* back unchanged."""
    passthrough = result
    return passthrough


def get_samplers():
    """Return the scheduler (sampler) names imported from stablepy."""
    samplers = scheduler_names
    return samplers


def get_vaes():
    """Return the VAE choices built at startup ("None" first, then local VAE files)."""
    vaes = vae_model_list
    return vaes


# When True, get_diffusers_model_list() returns the tupled (detailed) choices.
show_diffusers_model_list_detail = False
cached_diffusers_model_tupled_list = get_tupled_model_list(load_diffusers_format_model)
def get_diffusers_model_list():
    """Return the diffusers model choices for the current display mode.

    Detailed mode returns the cached tupled list; otherwise the plain id list.
    """
    if not show_diffusers_model_list_detail:
        return load_diffusers_format_model
    return cached_diffusers_model_tupled_list


def enable_diffusers_model_detail(is_enable: bool = False, model_name: str = ""):
    """Switch the diffusers model dropdown between detailed and plain choices.

    Sets the module flag read by ``get_diffusers_model_list``. The current
    selection *model_name* is kept when it is available in the new choice list;
    otherwise the first available model is selected. If the list is empty,
    *model_name* is passed through.

    Returns:
        ``(checkbox_update, dropdown_update)`` Gradio updates.
    """
    global show_diffusers_model_list_detail
    show_diffusers_model_list_detail = is_enable
    if is_enable:
        # Tupled choices: element [1] is the model id used as the dropdown value.
        available = [choice[1] for choice in cached_diffusers_model_tupled_list]
    else:
        available = list(load_diffusers_format_model)
    # Match by id, not by position. The tupled list is built separately by
    # get_tupled_model_list and may not be index-aligned with
    # load_diffusers_format_model.
    if model_name in available:
        new_value = model_name
    elif available:
        new_value = available[0]
    else:
        new_value = model_name
    return gr.update(value=is_enable), gr.update(value=new_value, choices=get_diffusers_model_list())


def get_t2i_model_info(repo_id: str):
    """Build a one-line Markdown summary of a public diffusers repo on the Hub.

    Returns ``""`` in any of these cases:

    * the id is empty or contains a space;
    * the repo does not exist or the lookup fails (the error is printed);
    * the repo is private or gated;
    * the repo is not tagged ``diffusers``.

    Otherwise returns ``gr.update(value=markdown)``. The Markdown lists the
    pipeline family (SDXL / SD1.5) when tagged, card tags minus generic ones,
    downloads, likes, last-modified date when known, and a repo link.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if not repo_id or " " in repo_id or not api.repo_exists(repo_id): return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:  # network/Hub boundary: report and degrade to no info
        print(f"Error: Failed to get {repo_id}'s info. {e}")
        return ""
    if model.private or model.gated: return ""
    tags = model.tags or []
    if 'diffusers' not in tags: return ""
    info = []
    url = f"https://huggingface.co/{repo_id}/"
    if 'diffusers:StableDiffusionXLPipeline' in tags:
        info.append("SDXL")
    elif 'diffusers:StableDiffusionPipeline' in tags:
        info.append("SD1.5")
    if model.card_data and model.card_data.tags:
        info.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    info.append(f"DLs: {model.downloads}")
    info.append(f"likes: {model.likes}")
    if model.last_modified:
        info.append(model.last_modified.strftime("lastmod: %Y-%m-%d"))
    md = f"Model Info: {', '.join(info)}, [Model Repo]({url})"
    return gr.update(value=md)


def load_model_prompt_dict(path: str = 'model_dict.json'):
    """Load the per-model recommended-prompt mapping from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``model_dict.json`` in the
            current working directory.

    Returns:
        The decoded mapping. Returns ``{}`` if the file is missing (silently),
        unreadable or not valid JSON (reported), or not a JSON object at the
        top level (reported).
    """
    import json
    try:
        with open(path, encoding='utf-8') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
        print(f"Failed to load {path}: {e}")
        return {}
    if not isinstance(loaded, dict):
        print(f"Ignoring {path}: expected a JSON object at top level.")
        return {}
    return loaded


# Per-model recommended prompts, read once at import from model_dict.json
# ({} when the file cannot be loaded).
model_prompt_dict = load_model_prompt_dict()


model_recom_prompt_enabled = True
animagine_ps = to_list("masterpiece, best quality, very aesthetic, absurdres")
animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
pony_ps = to_list("score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
pony_nps = to_list("source_pony, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
other_ps = to_list("anime artwork, anime style, studio anime, highly detailed, cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed")
other_nps = to_list("photo, deformed, black and white, realism, disfigured, low contrast, drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly")
default_ps = to_list("highly detailed, masterpiece, best quality, very aesthetic, absurdres")
default_nps = to_list("score_6, score_5, score_4, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
def insert_model_recom_prompt(prompt: str = "", neg_prompt: str = "", model_name: str = "None"):
    """Replace known quality tags in a prompt pair with the model's recommended tags.

    Tags from the Animagine / Pony / "other" preset lists are removed from both
    prompts, then the tags recorded for ``model_name`` in ``model_prompt_dict``
    (or ``default_ps`` / ``default_nps`` when there is no entry) are appended.
    Duplicates are dropped (via ``list_uniq``), and when the user supplied no
    tags of their own a trailing empty item is joined on, so the result ends
    with ", ".

    Args:
        prompt: Comma-separated positive prompt.
        neg_prompt: Comma-separated negative prompt.
        model_name: Key looked up in ``model_prompt_dict``.

    Returns:
        ``(prompt, neg_prompt)``; returned unchanged when the feature is
        disabled or ``model_name`` is empty.
    """
    if not model_recom_prompt_enabled or not model_name:
        return prompt, neg_prompt
    prompts = list_sub(to_list(prompt), animagine_ps + pony_ps + other_ps)
    neg_prompts = list_sub(to_list(neg_prompt), animagine_nps + pony_nps + other_nps)
    # The previous version also tested `type != "None"`, but `type` there was the
    # builtin class, so that test was always True; only emptiness matters.
    last_empty_p = [] if prompts else [""]
    last_empty_np = [] if neg_prompts else [""]
    entry = model_prompt_dict.get(model_name)
    if entry is not None:
        ps = to_list(entry["prompt"])
        nps = to_list(entry["negative_prompt"])
    else:
        ps = default_ps
        nps = default_nps
    prompt = ", ".join(list_uniq(prompts + ps) + last_empty_p)
    neg_prompt = ", ".join(list_uniq(neg_prompts + nps) + last_empty_np)
    return prompt, neg_prompt


def enable_model_recom_prompt(is_enable: bool = True):
    """Turn insertion of model-recommended prompt tags on or off.

    Stores ``is_enable`` in the module-level ``model_recom_prompt_enabled``
    flag and echoes the stored value back (handy as a Gradio callback output).
    """
    global model_recom_prompt_enabled
    model_recom_prompt_enabled = is_enable
    return model_recom_prompt_enabled


# LoRA metadata from the optional lora_dict.json, re-keyed through
# escape_lora_basename. Best-effort: any failure (missing file, bad JSON,
# helper error) leaves the dict empty or partially filled.
private_lora_dict = {}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        d = json.load(f)
        for k, v in d.items():
            private_lora_dict[escape_lora_basename(k)] = v
except Exception:
    pass


private_lora_model_list = get_private_lora_model_lists()
# LoRA key -> metadata list. Seeded with 5-item placeholders for "None"/"";
# get_all_lora_tupled_list reads item [1] (compared against "Pony") and item [2]
# to build dropdown labels. Entries are added lazily by update_lora_dict and
# get_all_lora_tupled_list.
loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]} | private_lora_dict.copy()
loras_url_to_path_dict = {} # {"URL to download": "local filepath", ...}
civitai_lora_last_results = {}  # {"URL to download": {search results}, ...}
# Copy of the most recent get_lora_model_list() result, refreshed by get_all_lora_list().
all_lora_list = []


def get_all_lora_list():
    """Fetch the LoRA list and cache a shallow copy in ``all_lora_list``.

    Returns:
        The list returned by ``get_lora_model_list()`` (not the cached copy).
    """
    global all_lora_list
    lora_files = get_lora_model_list()
    all_lora_list = lora_files.copy()
    return lora_files


def get_all_lora_tupled_list():
    """Build ``(label, path)`` dropdown choices for every known LoRA file.

    Metadata is taken from ``loras_dict`` or, on a miss, fetched with
    ``get_civitai_info`` and cached into ``loras_dict`` when not None. When
    the metadata has a non-empty item [2], the label becomes
    ``"<stem> (for <item[1]>, <item[2]>)"``, with a 🐴 appended to ``Pony``.

    Returns:
        List of ``(label, path)`` tuples; empty when no LoRA files are found.
    """
    global loras_dict
    models = get_all_lora_list()
    if not models:
        return []
    tupled_list = []
    for model in models:
        basename = Path(model).stem
        key = to_lora_key(model)
        if key in loras_dict:
            items = loras_dict.get(key)
        else:
            items = get_civitai_info(model)
            if items is not None:
                loras_dict[key] = items
        name = basename
        # Guard the index: entries loaded from lora_dict.json may be shorter
        # than the 5-item placeholders (NOTE(review): confirm expected length).
        if items and len(items) > 2 and items[2] != "":
            pony_mark = "🐴" if items[1] == "Pony" else ""
            name = f"{basename} (for {items[1]}{pony_mark}, {items[2]})"
        tupled_list.append((name, model))
    return tupled_list


def update_lora_dict(path: str):
    """Cache Civitai metadata for the LoRA at *path* in ``loras_dict``.

    Does nothing when the LoRA's key is already cached or when
    ``get_civitai_info`` returns None.
    """
    global loras_dict
    key = to_lora_key(path)
    if key in loras_dict:
        return
    items = get_civitai_info(path)
    if items is None:
        return
    loras_dict[key] = items


def download_lora(dl_urls: str):
    """Download comma-separated LoRA URLs into ``directory_loras``.

    URLs whose last path segment already exists in the directory are skipped.
    Every file that appears in the directory after downloading is renamed
    (in the same directory) to its ``escape_lora_basename`` form, registered
    via ``update_lora_dict`` and, where possible, recorded in
    ``loras_url_to_path_dict``.

    Args:
        dl_urls: Comma-separated download URLs; blank entries are ignored.

    Returns:
        Path string of the last renamed file, or "" when nothing new appeared.
    """
    global loras_url_to_path_dict
    dl_path = ""
    if not dl_urls:
        return dl_path
    before = get_local_model_list(directory_loras)
    urls = []
    for url in (u.strip() for u in dl_urls.split(',')):
        if not url:
            continue
        local_path = f"{directory_loras}/{url.split('/')[-1]}"
        if not Path(local_path).exists():
            download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
            urls.append(url)
    after = get_local_model_list(directory_loras)
    new_files = list_sub(after, before)
    for i, file in enumerate(new_files):
        path = Path(file)
        if not path.exists():
            continue
        # Rename inside the file's own directory; building the target from
        # only `path.parent.name` broke for nested or absolute LoRA dirs.
        new_path = path.parent / f"{escape_lora_basename(path.stem)}{path.suffix}"
        path.resolve().rename(new_path.resolve())
        # NOTE(review): URL<->file pairing is positional; a download that
        # yields no file or several files can misalign it. Guard against
        # IndexError when more new files appear than URLs were fetched.
        if i < len(urls):
            loras_url_to_path_dict[urls[i]] = str(new_path)
        update_lora_dict(str(new_path))
        dl_path = str(new_path)
    return dl_path


def copy_lora(path: str, new_path: str):
    """Copy a LoRA file to *new_path* and register it in ``loras_dict``.

    Returns:
        *new_path* on success (or immediately when source and destination are
        the same string); None when the source is missing or the copy fails.
    """
    import shutil
    if path == new_path:
        return new_path
    src, dst = Path(path), Path(new_path)
    if not src.exists():
        return None
    try:
        shutil.copy(str(src.resolve()), str(dst.resolve()))
    except Exception:
        return None
    update_lora_dict(str(dst))
    return new_path


def download_my_lora(dl_urls: str, lora1: str, lora2: str, lora3: str, lora4: str, lora5: str):
    """Download LoRA(s) and place the result in the first free slot.

    A slot counts as free when it is empty or the string "None". All five
    dropdowns are refreshed with the current LoRA choices.

    Returns:
        Five ``gr.update`` objects, one per slot.
    """
    slots = [lora1, lora2, lora3, lora4, lora5]
    downloaded = download_lora(dl_urls)
    if downloaded:
        free_idx = next((i for i, cur in enumerate(slots) if not cur or cur == "None"), None)
        if free_idx is not None:
            slots[free_idx] = downloaded
    choices = get_all_lora_tupled_list()
    return tuple(gr.update(value=value, choices=choices) for value in slots)


def set_prompt_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
    """Sync the five LoRA slots with ``<lora:key:weight>`` tags in *prompt*.

    Slot names are normalized with ``get_valid_lora_name``. When the prompt
    contains LoRA tags, slot weights are refreshed via ``get_valid_lora_wt``,
    and every tag whose LoRA is not already in a slot is placed into the first
    slot whose ``get_lora_info(...)[0]`` flag is falsy. That slot is then
    marked used, so later tags cannot overwrite it.

    Returns:
        Flat 10-tuple ``(lora1, lora1_wt, ..., lora5, lora5_wt)``.
    """
    import re
    names = [get_valid_lora_name(n) for n in (lora1, lora2, lora3, lora4, lora5)]
    weights = [lora1_wt, lora2_wt, lora3_wt, lora4_wt, lora5_wt]
    if prompt and "<lora" in prompt:
        weights = [get_valid_lora_wt(prompt, n, w) for n, w in zip(names, weights)]
        enabled = [get_lora_info(n)[0] for n in names]
        for part in prompt.split(","):
            part = part.strip()
            if "<lora" not in part:
                continue
            found = re.findall(r'<lora:(.+?):(.+?)>', part)
            if not found:
                continue
            key, wt = found[0]
            path = to_lora_path(key)
            if key not in loras_dict or not path:
                # NOTE(review): this passes the (possibly empty) resolved path,
                # not `key`, to get_valid_lora_name — kept as-is; confirm intent.
                path = get_valid_lora_name(path)
                if not path or path == "None":
                    continue
            if path in names:
                continue
            free_idx = next((i for i, on in enumerate(enabled) if not on), None)
            if free_idx is None:
                continue
            names[free_idx] = path
            weights[free_idx] = safe_float(wt)
            # Previously slot 4 re-queried get_lora_info here instead of marking
            # itself used, so a later tag could overwrite it.
            enabled[free_idx] = True
    return tuple(value for pair in zip(names, weights) for value in pair)


def apply_lora_prompt(prompt: str, lora_info: str):
    """Append a LoRA's trigger words to *prompt*.

    *lora_info* is split on both "," and "/". Both lists go through
    ``normalize_prompt_list`` and are merged without duplicates, and the result
    ends with ", ".

    Returns:
        ``gr.update`` with the new prompt; the prompt is returned unchanged when
        *lora_info* is None or "None".
    """
    if lora_info is None or lora_info == "None":
        return gr.update(value=prompt)
    tags = prompt.split(",") if prompt else []
    prompts = normalize_prompt_list(tags)
    lora_prompts = normalize_prompt_list(lora_info.replace("/", ",").split(","))
    prompt = ", ".join(list_uniq(prompts + lora_prompts) + [""])
    return gr.update(value=prompt)


def update_loras(prompt, lora1, lora1_wt, lora2, lora2_wt, lora3, lora3_wt, lora4, lora4_wt, lora5, lora5_wt):
    """Rewrite the prompt's ``<lora:...>`` tags from the five slots and refresh the slot widgets.

    An existing ``<lora:key:wt>`` tag is kept only if ``key`` is in
    ``loras_dict``, ``to_lora_path(key)`` is truthy and that path is one of the
    five slots; all other tags are dropped. A kept tag belonging to an enabled
    slot is written with that slot's weight. Otherwise a changed slider would
    leave the same LoRA in the prompt twice with two different weights. Tags
    for every enabled slot are then appended and duplicates removed.

    Returns:
        26 ``gr.update`` objects: the prompt, then for each slot:
        (value+choices, weight, tag/label/visibility, visibility, md/visibility).
    """
    import re
    slots = [(lora1, lora1_wt), (lora2, lora2_wt), (lora3, lora3_wt), (lora4, lora4_wt), (lora5, lora5_wt)]
    infos = [get_lora_info(path) for path, _ in slots]  # (on, label, tag, md) per slot
    lora_paths = [path for path, _ in slots]
    output_prompts = []
    for part in (prompt.split(",") if prompt else []):
        part = str(part).strip()
        if "<lora" in part:
            result = re.findall(r'<lora:(.+?):(.+?)>', part)
            if not result:
                continue
            key, wt = result[0]
            path = to_lora_path(key)
            if key not in loras_dict or not path:
                continue
            if path in lora_paths:
                slot_wt = next((s_wt for (s_path, s_wt), info in zip(slots, infos) if s_path == path and info[0]), None)
                weight = safe_float(wt) if slot_wt is None else slot_wt
                output_prompts.append(f"<lora:{to_lora_key(path)}:{weight:.2f}>")
        elif part:
            output_prompts.append(part)
    lora_prompts = [f"<lora:{to_lora_key(path)}:{wt:.2f}>" for (path, wt), info in zip(slots, infos) if info[0]]
    output_prompt = ", ".join(list_uniq(output_prompts + lora_prompts + [""]))
    choices = get_all_lora_tupled_list()
    updates = [gr.update(value=output_prompt)]
    for (path, wt), (on, label, tag, md) in zip(slots, infos):
        updates += [
            gr.update(value=path, choices=choices),
            gr.update(value=wt),
            gr.update(value=tag, label=label, visible=on),
            gr.update(visible=on),
            gr.update(value=md, visible=on),
        ]
    return tuple(updates)


def search_civitai_lora(query, base_model):
    """Search Civitai for LoRAs and populate the result dropdown.

    Each result is stored in ``civitai_lora_last_results`` keyed by its
    ``dl_url``. The dropdown is pre-selected to the first hit, and that hit's
    ``md`` text is shown.

    Returns:
        Four ``gr.update`` objects (dropdown, markdown, two visibility toggles).
        On no results, the dropdown and markdown are hidden.
    """
    global civitai_lora_last_results
    items = search_lora_on_civitai(query, base_model)
    if not items: return gr.update(choices=[("", "")], value="", visible=False),\
          gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
    civitai_lora_last_results = {}
    choices = []
    for item in items:
        base_model_name = "Pony🐴" if item['base_model'] == "Pony" else item['base_model']
        name = f"{item['name']} (for {base_model_name} / By: {item['creator']} / Tags: {', '.join(item['tags'])})"
        value = item['dl_url']
        choices.append((name, value))
        civitai_lora_last_results[value] = item
    # Still needed if `items` is a truthy but empty iterable.
    if not choices: return gr.update(choices=[("", "")], value="", visible=False),\
          gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)
    # Default to None (not the string "None", which cannot be indexed by 'md').
    result = civitai_lora_last_results.get(choices[0][1])
    md = result.get('md', "") if result else ""
    return gr.update(choices=choices, value=choices[0][1], visible=True), gr.update(value=md, visible=True),\
          gr.update(visible=True), gr.update(visible=True)


def select_civitai_lora(search_result):
    """Show the URL and description of the chosen Civitai search result.

    Returns:
        ``(url_update, markdown_update)``. A non-URL selection clears the URL
        and shows "None"; a URL missing from ``civitai_lora_last_results`` shows
        an empty description instead of crashing.
    """
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # Default to None: the old "None" string default raised TypeError on ['md'].
    result = civitai_lora_last_results.get(search_result)
    md = result.get('md', "") if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)


def search_civitai_lora_json(query, base_model):
    """Return Civitai LoRA search results as a JSON-component update.

    The value is a dict mapping each result's ``dl_url`` to the full result
    item (later duplicates win); it is empty when the search finds nothing.
    """
    found = search_lora_on_civitai(query, base_model)
    by_url = {entry['dl_url']: entry for entry in found} if found else {}
    return gr.update(value=by_url)


# Quality presets. Each entry has a "name" plus comma-separated "prompt" and
# "negative_prompt" tag strings. preset_quality indexes these by "name", and
# process_style_prompt strips the tags of every entry from the user's prompts
# before appending the selected preset's tags.
# NOTE: process_style_prompt splits on every comma, so a weighted group such as
# "(low quality, worst quality:1.2)" is treated as two separate pieces.
quality_prompt_list = [
    {
        "name": "None",
        "prompt": "",
        "negative_prompt": "lowres",
    },
    {
        "name": "Animagine Common",
        "prompt": "anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    },
    {
        "name": "Pony Anime Common",
        "prompt": "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
    },
    {
        "name": "Pony Common",
        "prompt": "source_anime, score_9, score_8_up, score_7_up",
        "negative_prompt": "source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends",
    },
    {
        "name": "Animagine Standard v3.0",
        "prompt": "masterpiece, best quality",
        "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name",
    },
    {
        "name": "Animagine Standard v3.1",
        "prompt": "masterpiece, best quality, very aesthetic, absurdres",
        "negative_prompt": "lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
    },
    {
        "name": "Animagine Light v3.1",
        "prompt": "(masterpiece), best quality, very aesthetic, perfect face",
        "negative_prompt": "(low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
    },
    {
        "name": "Animagine Heavy v3.1",
        "prompt": "(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details",
        "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing",
    },
]


# Style presets, shaped like quality_prompt_list: "name" plus comma-separated
# "prompt" / "negative_prompt" tag strings. preset_styles indexes these by
# "name", and process_style_prompt strips the tags of every entry before
# appending the selected style's tags.
style_list = [
    {
        "name": "None",
        "prompt": "",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still, emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo, 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork, anime style, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style, vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art, digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art, low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art, magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style, cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model, octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]


# Lookup tables: preset name -> (prompt, negative_prompt).
preset_styles = dict((entry["name"], (entry["prompt"], entry["negative_prompt"])) for entry in style_list)
preset_quality = dict((entry["name"], (entry["prompt"], entry["negative_prompt"])) for entry in quality_prompt_list)


def process_style_prompt(prompt: str, neg_prompt: str, styles_key: str = "None", quality_key: str = "None"):
    """Re-apply a style preset and a quality preset to a prompt pair.

    Every tag from any style/quality preset, and from the built-in Animagine
    and Pony tag sets, is first removed from both prompts. The tags of the
    selected style (``preset_styles[styles_key]``) and quality
    (``preset_quality[quality_key]``) presets are then appended, and duplicates
    are removed, keeping first occurrence. An unknown key contributes no tags.

    When both keys are something other than "None" and the user had no tags of
    their own left, a trailing empty item is joined on, so the result ends with
    ", ".

    Note: the Animagine/Pony tag sets are only stripped, never added. The
    previous version tested the builtin ``type`` against "Animagine"/"Pony", so
    those branches could never run and were removed.

    Returns:
        ``(gr.update(value=prompt), gr.update(value=neg_prompt))``.
    """
    def to_list(s):
        # Filter per piece: filtering on the whole string kept empty items
        # (e.g. from this function's own trailing ", "), yielding ", ," output.
        return [x.strip() for x in (s or "").split(",") if x.strip()]

    def list_sub(a, b):
        excluded = set(b)
        return [e for e in a if e not in excluded]

    def list_uniq(l):
        # Order-preserving de-duplication (first occurrence wins).
        return list(dict.fromkeys(l))

    animagine_ps = to_list("anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres")
    animagine_nps = to_list("lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]")
    pony_ps = to_list("source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres")
    pony_nps = to_list("source_pony, source_furry, source_cartoon, score_6, score_5, score_4, busty, ugly face, mutated hands, low res, blurry face, black and white, the simpsons, overwatch, apex legends")
    prompts = to_list(prompt)
    neg_prompts = to_list(neg_prompt)

    all_styles_ps = [t for d in style_list for t in to_list(str(d.get("prompt", "")))]
    all_styles_nps = [t for d in style_list for t in to_list(str(d.get("negative_prompt", "")))]
    all_quality_ps = [t for d in quality_prompt_list for t in to_list(str(d.get("prompt", "")))]
    all_quality_nps = [t for d in quality_prompt_list for t in to_list(str(d.get("negative_prompt", "")))]

    quality_prompt, quality_neg = preset_quality.get(quality_key, ("", ""))
    style_prompt, style_neg = preset_styles.get(styles_key, ("", ""))
    quality_ps, quality_nps = to_list(quality_prompt), to_list(quality_neg)
    styles_ps, styles_nps = to_list(style_prompt), to_list(style_neg)

    prompts = list_sub(prompts, animagine_ps + pony_ps + all_styles_ps + all_quality_ps)
    neg_prompts = list_sub(neg_prompts, animagine_nps + pony_nps + all_styles_nps + all_quality_nps)

    presets_selected = styles_key != "None" and quality_key != "None"
    last_empty_p = [""] if not prompts and presets_selected else []
    last_empty_np = [""] if not neg_prompts and presets_selected else []

    prompts = prompts + styles_ps + quality_ps
    neg_prompts = neg_prompts + styles_nps + quality_nps

    prompt = ", ".join(list_uniq(prompts) + last_empty_p)
    neg_prompt = ", ".join(list_uniq(neg_prompts) + last_empty_np)

    return gr.update(value=prompt), gr.update(value=neg_prompt)