from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import (
    AutoencoderKL,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLPipeline,
)

from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import get_hf_token, get_is_sdxl, num_return_sequences


class AbstractPipeline:
    """Base interface for the pipeline wrappers in this module.

    Both hooks are intentionally empty here and simply return ``None``;
    concrete wrappers override them.
    """

    def load(self, model_dir: str):
        """Load weights from ``model_dir``. The base implementation does nothing."""

    def create(self, pipe):
        """Build this wrapper from another wrapper ``pipe``. Does nothing here."""


class Text2Img(AbstractPipeline):
    """Text-to-image wrapper.

    Uses ``StableDiffusionXLPipeline`` when ``get_is_sdxl()`` is true, otherwise
    the project's ``two_step_pipeline``. Pipelines are created in float16 and
    moved to ``"cuda"``.
    """

    @dataclass
    class Params:
        """Prompt inputs for :meth:`Text2Img.process`; every field defaults to ``None``."""

        prompt: Optional[List[str]] = None
        modified_prompt: Optional[List[str]] = None
        # When both left and right prompts are truthy, ``process`` takes the
        # multi-character path.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the text-to-image pipeline from ``model_dir`` and move it to CUDA.

        On SDXL, the pipeline's VAE is replaced with the fp16 VAE from
        ``"madebyollin/sdxl-vae-fp16-fix"``.
        """
        if get_is_sdxl():
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            pipe = StableDiffusionXLPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                use_safetensors=True,
            )
            pipe.vae = vae
            pipe.to("cuda")
            self.pipe = pipe
        else:
            self.pipe = two_step_pipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
            ).to("cuda")
        self.__patch()

    def is_loaded(self):
        """Return ``True`` once ``load`` or ``create`` has assigned ``self.pipe``."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of another wrapper's ``pipe``.

        Components are reused (not reloaded) and the new pipeline is moved to CUDA.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLPipeline(**pipeline.pipe.components).to("cuda")
        else:
            self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``.

        VAE tiling/slicing is enabled only on SDXL; xformers attention is always enabled.
        """
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        num_inference_steps: int,
        height: int,
        width: int,
        negative_prompt: str,
        iteration: float = 3.0,
        **kwargs,
    ):
        """Run text-to-image generation and wrap the output in a ``Result``.

        If both ``params.prompt_left`` and ``params.prompt_right`` are truthy,
        ``self.pipe.multi_character_diffusion`` is called with the first entry of
        ``prompt``, ``prompt_left`` and ``prompt_right``. Otherwise ``self.pipe`` is
        called directly. On SDXL only ``params.modified_prompt`` is used as the
        prompt; on the two-step pipeline, ``prompt``, ``modified_prompt`` and
        ``iteration`` are forwarded.

        Extra ``**kwargs`` are forwarded to the underlying pipeline call in both
        paths and override the values built here.
        """
        prompt = params.prompt

        if params.prompt_left and params.prompt_right:
            # multi-character pipelines
            # NOTE(review): only two_step_pipeline is expected to provide
            # multi_character_diffusion; on SDXL this attribute is likely missing.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            call_kwargs = {
                "prompt": prompt,
                # presumably per-prompt region layout / mixing weights for the
                # (full, left, right) prompts — confirm against two_step_pipeline
                "pos": ["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                "mix_val": [0.2, 0.8, 0.8],
                "height": height,
                "width": width,
                "num_inference_steps": num_inference_steps,
                "negative_prompt": [negative_prompt or ""] * len(prompt),
                **kwargs,
            }
            result = self.pipe.multi_character_diffusion(**call_kwargs)
        else:
            # two step pipeline
            modified_prompt = params.modified_prompt

            # Built under its own name so the caller's **kwargs are not lost
            # (they were previously overwritten here and silently dropped).
            if get_is_sdxl():
                print("Warning: Two step pipeline is not supported on SDXL")
                prompt_kwargs = {
                    "prompt": modified_prompt,
                }
            else:
                prompt_kwargs = {
                    "prompt": prompt,
                    "modified_prompts": modified_prompt,
                    "iteration": iteration,
                }

            # NOTE(review): the negative-prompt list length assumes the prompt
            # batch size equals num_return_sequences — confirm with callers.
            call_kwargs = {
                "height": height,
                "width": width,
                "negative_prompt": [negative_prompt or ""] * num_return_sequences,
                "num_inference_steps": num_inference_steps,
                **prompt_kwargs,
                **kwargs,
            }
            result = self.pipe(**call_kwargs)

        return Result.from_result(result)


class Img2Img(AbstractPipeline):
    """Image-to-image wrapper.

    Uses ``StableDiffusionXLImg2ImgPipeline`` when ``get_is_sdxl()`` is true,
    otherwise ``StableDiffusionImg2ImgPipeline``. Pipelines are float16 on ``"cuda"``.
    """

    # Class-level default; ``load``/``create`` set the per-instance flag to True.
    __loaded = False

    def load(self, model_dir: str):
        """Load the image-to-image pipeline from ``model_dir`` unless already loaded."""
        if self.__loaded:
            return

        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                model_dir,
                torch_dtype=torch.float16,
                use_auth_token=get_hf_token(),
                use_safetensors=True,
            ).to("cuda")
        else:
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
            ).to("cuda")
        self.__patch()

        self.__loaded = True

    def is_loaded(self) -> bool:
        """Return ``True`` once ``load`` or ``create`` has completed (mirrors ``Text2Img``)."""
        return self.__loaded

    def create(self, pipeline: AbstractPipeline):
        """Build this pipeline from the components of another wrapper's ``pipe``.

        Components are reused (not reloaded) and the new pipeline is moved to CUDA.
        """
        if get_is_sdxl():
            self.pipe = StableDiffusionXLImg2ImgPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        else:
            self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
                "cuda"
            )
        self.__patch()

        self.__loaded = True

    def __patch(self):
        """Enable memory-saving options on ``self.pipe``.

        VAE tiling/slicing is enabled only on SDXL; xformers attention is always enabled.
        """
        if get_is_sdxl():
            self.pipe.enable_vae_tiling()
            self.pipe.enable_vae_slicing()
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        num_inference_steps: int,
        width: int,
        height: int,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Download ``imageUrl``, resize it to ``(width, height)`` and run img2img.

        Generates one image per prompt (``num_images_per_prompt=1``). Extra
        ``**kwargs`` are forwarded to the pipeline and override the values built
        here. Returns ``Result.from_result`` of the pipeline output.
        """
        image = download_image(imageUrl).resize((width, height))

        call_kwargs = {
            "prompt": prompt,
            "image": image,
            "strength": strength,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "num_inference_steps": num_inference_steps,
            **kwargs,
        }
        result = self.pipe(**call_kwargs)
        return Result.from_result(result)