import spaces
import os
import torch
import logging
import random
import gradio as gr
import diffusers
from models.upscaler import upscaler_dict_gui
from stablepy import Model_Diffusers
from utils.download_utils import download_things

# Keep diffusers quiet: only ERROR-and-above records are emitted, both through
# the stdlib "diffusers" logger and diffusers' own verbosity setting.
logging.getLogger("diffusers").setLevel(logging.ERROR)
diffusers.utils.logging.set_verbosity(logging.ERROR)  # logging.ERROR == 40

# Hugging Face access token (e.g. handed to download_things when fetching
# upscaler weights). os.environ.get returns None when HF_TOKEN is unset, so
# the annotation must allow None.
hf_token: "str | None" = os.environ.get("HF_TOKEN")


class GuiSD:
    """Bridge between the Gradio UI and a stablepy ``Model_Diffusers`` pipeline.

    ``__init__`` eagerly loads a default checkpoint
    (``models/animaPencilXL_v500.safetensors``); the GUI then switches
    checkpoints through :meth:`load_new_model` and produces images through
    :meth:`generate_pipeline`.
    """

    def __init__(self,
                 model_list,
                 task_stablepy,
                 lora_model_list,
                 embed_list,
                 stream=True):
        """Load the default pipeline and remember the GUI's option lists.

        :param model_list: checkpoint names offered by the GUI; only these get
            the VAE/LoRA compatibility checks.
        :param task_stablepy: mapping from GUI task labels to stablepy task
            names.
        :param lora_model_list: LoRA names offered by the GUI.
        :param embed_list: textual-inversion embeddings, forwarded when
            textual inversion is enabled on a non-SDXL pipeline.
        :param stream: stored as ``self.stream``; not read elsewhere in this
            class.
        """
        self.model = None
        print("Loading model...")
        self.model = Model_Diffusers(
            base_model_id="models/animaPencilXL_v500.safetensors",
            task_name="txt2img",
            vae_model="vaes/sdXL_v10VAEFix.safetensors",
            type_model_precision=torch.float16,
            retain_task_model_in_cache=False,
        )
        self.model_list = model_list
        self.task_stablepy = task_stablepy
        self.lora_model_list = lora_model_list
        self.embed_list = embed_list
        self.stream = stream

    def load_new_model(
            self,
            model_name,
            vae_model,
            task,
            progress=gr.Progress(track_tqdm=True)):
        """Switch the pipeline to another checkpoint, streaming status text.

        :param model_name: checkpoint to load; forwarded to ``load_pipe``.
        :param vae_model: VAE name; the string "None" selects the model's own
            VAE. For checkpoints listed in ``self.model_list`` the VAE is also
            dropped (``None``) when it looks built for the other architecture.
        :param task: GUI task label, translated via ``self.task_stablepy``.
        :param progress: not referenced in the body; presumably present so
            Gradio injects a progress tracker that follows tqdm bars.
        :yields: "Loading model: <name>" before loading and
            "Model loaded: <name>" once ``load_pipe`` returns.
        """
        yield f"Loading model: {model_name}"

        vae_model = self._none_if_unset(vae_model)

        if model_name in self.model_list:
            model_is_xl = "xl" in model_name.lower()
            if self._vae_is_incompatible(model_is_xl, vae_model):
                vae_model = None

        self.model.load_pipe(
            model_name,
            task_name=self.task_stablepy[task],
            vae_model=vae_model,
            type_model_precision=torch.float16,
            retain_task_model_in_cache=False,
        )
        yield f"Model loaded: {model_name}"

    @spaces.GPU
    def generate_pipeline(
            self,
            prompt,
            neg_prompt,
            num_images,
            steps,
            cfg,
            clip_skip,
            seed,
            lora1,
            lora_scale1,
            lora2,
            lora_scale2,
            lora3,
            lora_scale3,
            lora4,
            lora_scale4,
            lora5,
            lora_scale5,
            sampler,
            img_height,
            img_width,
            model_name,
            vae_model,
            task,
            image_control,
            preprocessor_name,
            preprocess_resolution,
            image_resolution,
            style_prompt,  # list []
            style_json_file,
            image_mask,
            strength,
            low_threshold,
            high_threshold,
            value_threshold,
            distance_threshold,
            controlnet_output_scaling_in_unet,
            controlnet_start_threshold,
            controlnet_stop_threshold,
            textual_inversion,
            syntax_weights,
            upscaler_model_path,
            upscaler_increases_size,
            esrgan_tile,
            esrgan_tile_overlap,
            hires_steps,
            hires_denoising_strength,
            hires_sampler,
            hires_prompt,
            hires_negative_prompt,
            hires_before_adetailer,
            hires_after_adetailer,
            loop_generation,
            leave_progress_bar,
            disable_progress_bar,
            image_previews,
            display_images,
            save_generated_images,
            image_storage_location,
            retain_compel_previous_load,
            retain_detailfix_model_previous_load,
            retain_hires_model_previous_load,
            t2i_adapter_preprocessor,
            t2i_adapter_conditioning_scale,
            t2i_adapter_conditioning_factor,
            xformers_memory_efficient_attention,
            freeu,
            generator_in_cpu,
            adetailer_inpaint_only,
            adetailer_verbose,
            adetailer_sampler,
            adetailer_active_a,
            prompt_ad_a,
            negative_prompt_ad_a,
            strength_ad_a,
            face_detector_ad_a,
            person_detector_ad_a,
            hand_detector_ad_a,
            mask_dilation_a,
            mask_blur_a,
            mask_padding_a,
            adetailer_active_b,
            prompt_ad_b,
            negative_prompt_ad_b,
            strength_ad_b,
            face_detector_ad_b,
            person_detector_ad_b,
            hand_detector_ad_b,
            mask_dilation_b,
            mask_blur_b,
            mask_padding_b,
            retain_task_cache_gui,
            image_ip1,
            mask_ip1,
            model_ip1,
            mode_ip1,
            scale_ip1,
            image_ip2,
            mask_ip2,
            model_ip2,
            mode_ip2,
            scale_ip2):
        """Run one generation request from the GUI, streaming results.

        Most parameters mirror one GUI control each and are forwarded to the
        pipeline under the key shown in ``pipe_params`` below. Parameters with
        extra handling here:

        :param model_name: checkpoint (re)loaded before generating.
        :param vae_model: VAE name; "None" selects the model's default. Reset
            to ``None`` (with a GUI notice) when it looks built for the other
            architecture.
        :param task: GUI task label, mapped through ``self.task_stablepy``.
        :param lora1..lora5: LoRA names; "None" disables a slot.
        :param style_json_file: accepted but ignored; ``""`` is forwarded.
        :param upscaler_model_path: ``None``, "Lanczos", "Nearest", or a key
            of ``upscaler_dict_gui`` whose weights are downloaded on first use.
        :param adetailer_inpaint_only: applied to both ADetailer passes (A, B).
        :param adetailer_sampler: applied to both ADetailer passes (A, B).
        :param adetailer_verbose: switches the ``ultralytics`` logger between
            INFO and ERROR.
        :param retain_task_cache_gui: forwarded as
            ``retain_task_model_in_cache``.
        :param image_ip1..scale_ip2: two IP-Adapter slots; a slot is used only
            when its image is set.
        :raises ValueError: when a non-txt2img task has no control image, or
            the inpaint task has no mask.
        :yields: ``(img, info_state)`` for every item the pipeline produces;
            ``info_state`` becomes "COMPLETED. Seeds: ..." (plus any VAE/LoRA
            notices) once the pipeline reports data.
        """
        vae_model = self._none_if_unset(vae_model)
        loras_list: list = [lora1, lora2, lora3, lora4, lora5]
        vae_msg: str = f"VAE: {vae_model}" if vae_model else ""
        msg_lora: list = []

        if model_name in self.model_list:
            model_is_xl = "xl" in model_name.lower()
            model_type = "SDXL" if model_is_xl else "SD 1.5"
            other_type = "SD 1.5" if model_is_xl else "SDXL"

            if self._vae_is_incompatible(model_is_xl, vae_model):
                msg_inc_vae = (
                    f"The selected VAE is for a {other_type} model, but you"
                    f" are using a {model_type} model. The default VAE "
                    "will be used."
                )
                gr.Info(msg_inc_vae)
                vae_msg = msg_inc_vae
                vae_model = None

            for la in loras_list:
                if la is None or la == "None" or la not in self.lora_model_list:
                    continue

                print(la)
                # Name-based heuristic: only these LoRAs are treated as SD 1.5;
                # every other LoRA is assumed to be SDXL.
                # NOTE(review): brittle; confirm against the LoRA catalogue.
                name = la.lower()
                is_sd15_lora = "animetarot" in name or "hyper-sd15-8steps" in name
                # Mismatch: SD 1.5 LoRA on an SDXL model, or vice versa.
                if model_is_xl == is_sd15_lora:
                    msg_inc_lora = f"The LoRA {la} is for {other_type}, but you are using {model_type}."
                    gr.Info(msg_inc_lora)
                    msg_lora.append(msg_inc_lora)

        task = self.task_stablepy[task]

        (params_ip_img,
         params_ip_msk,
         params_ip_model,
         params_ip_mode,
         params_ip_scale) = self._collect_ip_adapter_params([
            (image_ip1, mask_ip1, model_ip1, mode_ip1, scale_ip1),
            (image_ip2, mask_ip2, model_ip2, mode_ip2, scale_ip2),
        ])

        # First load. ``__init__`` always assigns ``self.model``, so this
        # presumably only runs if that attribute was cleared elsewhere.
        model_precision = torch.float16
        if not self.model:
            from modelstream import Model_Diffusers2

            print("Loading model...")
            self.model = Model_Diffusers2(
                base_model_id=model_name,
                task_name=task,
                vae_model=vae_model,
                type_model_precision=model_precision,
                retain_task_model_in_cache=retain_task_cache_gui,
            )

        if task != "txt2img" and not image_control:
            raise ValueError(
                "No control image found: To use this function, "
                "you have to upload an image in 'Image ControlNet/Inpaint/Img2img'"
            )

        if task == "inpaint" and not image_mask:
            raise ValueError("No mask image found: Specify one in 'Image Mask'")

        upscaler_model = self._resolve_upscaler(upscaler_model_path)

        logging.getLogger("ultralytics").setLevel(logging.INFO if adetailer_verbose else logging.ERROR)

        print("Config model:", model_name, vae_model, loras_list)

        self.model.load_pipe(
            model_name,
            task_name=task,
            vae_model=vae_model,
            type_model_precision=model_precision,
            retain_task_model_in_cache=retain_task_cache_gui,
        )

        is_sdxl_pipe = self.model.class_name == "StableDiffusionXLPipeline"
        if textual_inversion and is_sdxl_pipe:
            print("No Textual inversion for SDXL")

        adetailer_params_A: dict = {
            "face_detector_ad": face_detector_ad_a,
            "person_detector_ad": person_detector_ad_a,
            "hand_detector_ad": hand_detector_ad_a,
            "prompt": prompt_ad_a,
            "negative_prompt": negative_prompt_ad_a,
            "strength": strength_ad_a,
            # "image_list_task" : None,
            "mask_dilation": mask_dilation_a,
            "mask_blur": mask_blur_a,
            "mask_padding": mask_padding_a,
            "inpaint_only": adetailer_inpaint_only,
            "sampler": adetailer_sampler,
        }
        adetailer_params_B: dict = {
            "face_detector_ad": face_detector_ad_b,
            "person_detector_ad": person_detector_ad_b,
            "hand_detector_ad": hand_detector_ad_b,
            "prompt": prompt_ad_b,
            "negative_prompt": negative_prompt_ad_b,
            "strength": strength_ad_b,
            # "image_list_task" : None,
            "mask_dilation": mask_dilation_b,
            "mask_blur": mask_blur_b,
            "mask_padding": mask_padding_b,
            # The GUI exposes a single inpaint-only / sampler control for
            # ADetailer; previously only pass A received it.
            "inpaint_only": adetailer_inpaint_only,
            "sampler": adetailer_sampler,
        }
        pipe_params: dict = {
            "prompt": prompt,
            "negative_prompt": neg_prompt,
            "img_height": img_height,
            "img_width": img_width,
            "num_images": num_images,
            "num_steps": steps,
            "guidance_scale": cfg,
            "clip_skip": clip_skip,
            "seed": seed,
            "image": image_control,
            "preprocessor_name": preprocessor_name,
            "preprocess_resolution": preprocess_resolution,
            "image_resolution": image_resolution,
            "style_prompt": style_prompt if style_prompt else "",
            "style_json_file": "",
            "image_mask": image_mask,  # only for Inpaint
            "strength": strength,  # only for Inpaint or ...
            "low_threshold": low_threshold,
            "high_threshold": high_threshold,
            "value_threshold": value_threshold,
            "distance_threshold": distance_threshold,
            "lora_A": self._none_if_unset(lora1),
            "lora_scale_A": lora_scale1,
            "lora_B": self._none_if_unset(lora2),
            "lora_scale_B": lora_scale2,
            "lora_C": self._none_if_unset(lora3),
            "lora_scale_C": lora_scale3,
            "lora_D": self._none_if_unset(lora4),
            "lora_scale_D": lora_scale4,
            "lora_E": self._none_if_unset(lora5),
            "lora_scale_E": lora_scale5,
            "textual_inversion": self.embed_list if textual_inversion and not is_sdxl_pipe else [],
            "syntax_weights": syntax_weights,  # "Classic"
            "sampler": sampler,
            "xformers_memory_efficient_attention": xformers_memory_efficient_attention,
            "gui_active": True,
            "loop_generation": loop_generation,
            "controlnet_conditioning_scale": float(controlnet_output_scaling_in_unet),
            "control_guidance_start": float(controlnet_start_threshold),
            "control_guidance_end": float(controlnet_stop_threshold),
            "generator_in_cpu": generator_in_cpu,
            "FreeU": freeu,
            "adetailer_A": adetailer_active_a,
            "adetailer_A_params": adetailer_params_A,
            "adetailer_B": adetailer_active_b,
            "adetailer_B_params": adetailer_params_B,
            "leave_progress_bar": leave_progress_bar,
            "disable_progress_bar": disable_progress_bar,
            "image_previews": image_previews,
            "display_images": display_images,
            "save_generated_images": save_generated_images,
            "image_storage_location": image_storage_location,
            "retain_compel_previous_load": retain_compel_previous_load,
            "retain_detailfix_model_previous_load": retain_detailfix_model_previous_load,
            "retain_hires_model_previous_load": retain_hires_model_previous_load,
            "t2i_adapter_preprocessor": t2i_adapter_preprocessor,
            "t2i_adapter_conditioning_scale": float(t2i_adapter_conditioning_scale),
            "t2i_adapter_conditioning_factor": float(t2i_adapter_conditioning_factor),
            "upscaler_model_path": upscaler_model,
            "upscaler_increases_size": upscaler_increases_size,
            "esrgan_tile": esrgan_tile,
            "esrgan_tile_overlap": esrgan_tile_overlap,
            "hires_steps": hires_steps,
            "hires_denoising_strength": hires_denoising_strength,
            "hires_prompt": hires_prompt,
            "hires_negative_prompt": hires_negative_prompt,
            "hires_sampler": hires_sampler,
            "hires_before_adetailer": hires_before_adetailer,
            "hires_after_adetailer": hires_after_adetailer,
            "ip_adapter_image": params_ip_img,
            "ip_adapter_mask": params_ip_msk,
            "ip_adapter_model": params_ip_model,
            "ip_adapter_mode": params_ip_mode,
            "ip_adapter_scale": params_ip_scale,
        }

        # Bonus roll: randint(1, 100) < 25 gives a 24% chance to double a
        # small, cheap batch (no upscaler, no ADetailer, < 45 steps).
        random_number: int = random.randint(1, 100)
        if random_number < 25 and num_images < 3:
            if (not upscaler_model and
                    steps < 45 and
                    task in ["txt2img", "img2img"] and
                    not adetailer_active_a and
                    not adetailer_active_b):
                num_images *= 2
                pipe_params["num_images"] = num_images
                gr.Info("Num images x 2 🎉")

        # Maybe fix lora issue: 'Cannot copy out of meta tensor; no data!''
        self.model.pipe.to("cuda:0" if torch.cuda.is_available() else "cpu")

        info_state = "PROCESSING"
        # ``seeds`` deliberately differs from the ``seed`` parameter above.
        for img, seeds, data in self.model(**pipe_params):
            info_state += "."
            if data:
                info_state = f"COMPLETED. Seeds: {seeds}"
                if vae_msg:
                    info_state += "<br>" + vae_msg
                if msg_lora:
                    info_state += "<br>" + "<br>".join(msg_lora)
            yield img, info_state

    @staticmethod
    def _none_if_unset(value):
        """Map the GUI's "None" placeholder string to a real ``None``."""
        return None if value == "None" else value

    @staticmethod
    def _vae_is_incompatible(model_is_xl, vae_model):
        """Return True when *vae_model* appears to target the other architecture.

        Name-based heuristic: a VAE whose name contains "sdxl" is treated as an
        SDXL VAE. An SDXL model rejects any non-SDXL VAE; an SD 1.5 model
        rejects an SDXL VAE. A falsy *vae_model* (default VAE) is never
        flagged.
        """
        vae_is_sdxl = bool(vae_model) and "sdxl" in vae_model.lower()
        if model_is_xl:
            return bool(vae_model) and not vae_is_sdxl
        return vae_is_sdxl

    @staticmethod
    def _resolve_upscaler(upscaler_model_path):
        """Return the value to forward as ``upscaler_model_path``.

        ``None``, "Lanczos" and "Nearest" pass through unchanged. Any other
        choice is looked up in ``upscaler_dict_gui``; its weights file is
        downloaded into ``./upscalers`` when missing, and that local path is
        returned.
        """
        if upscaler_model_path in (None, "Lanczos", "Nearest"):
            return upscaler_model_path

        directory_upscalers = "upscalers"
        os.makedirs(directory_upscalers, exist_ok=True)

        url_upscaler = upscaler_dict_gui[upscaler_model_path]
        local_path = f"./{directory_upscalers}/{url_upscaler.split('/')[-1]}"
        if not os.path.exists(local_path):
            download_things(directory_upscalers, url_upscaler, hf_token)
        return local_path

    @staticmethod
    def _collect_ip_adapter_params(adapters):
        """Gather per-slot IP-Adapter settings into parallel lists.

        :param adapters: iterable of ``(image, mask, model, mode, scale)``
            tuples; slots with a falsy image are skipped.
        :returns: ``(images, masks, models, modes, scales)`` lists.
        """
        images, masks, models, modes, scales = [], [], [], [], []
        for image, mask, model, mode, scale in adapters:
            if not image:
                continue
            images.append(image)
            # NOTE(review): masks are appended only when present, so with one
            # masked and one unmasked slot the mask list is shorter than the
            # image list and positions no longer line up. Kept as-is because
            # the pipeline's handling of missing masks is not visible here.
            if mask:
                masks.append(mask)
            models.append(model)
            modes.append(mode)
            scales.append(scale)
        return images, masks, models, modes, scales