from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionImg2ImgPipeline

from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import download_image
from internals.util.config import get_hf_token, num_return_sequences


class AbstractPipeline:
    """Base class for pipelines that either load weights from disk (`load`)
    or are built by sharing components with an existing pipeline (`create`)."""

    def load(self, model_dir: str):
        pass

    def create(self, pipe):
        pass


class Text2Img(AbstractPipeline):
    @dataclass
    class Params:
        prompt: Optional[List[str]] = None
        modified_prompt: Optional[List[str]] = None
        # Left/right region prompts trigger the multi-character path.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
        ).to("cuda")
        self.__patch()

    def is_loaded(self):
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
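        # xformers memory-efficient attention lowers VRAM usage during inference.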
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
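        # Route to the multi-character pipeline when left/right prompts are
        # provided; otherwise run the standard two-step pipeline.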
        prompt = params.prompt

        if params.prompt_left and params.prompt_right:
            # Multi-character pipeline: base prompt plus one prompt per region.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            result = self.pipe.multi_character_diffusion(
                prompt=prompt,
                # Grid placement ("rows:cols-row:col"): full canvas, then the
                # left and right halves of a 1x2 grid.
                pos=["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                mix_val=[0.2, 0.8, 0.8],
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=[negative_prompt or ""] * len(prompt),
                num_images_per_prompt=num_return_sequences,
                eta=eta,
                # generator=generator,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
            )
        else:
            # Two-step pipeline: pass both the original and the modified prompts.
            modified_prompt = params.modified_prompt

            result = self.pipe.two_step_pipeline(
                prompt=prompt,
                modified_prompts=modified_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=[negative_prompt or ""] * num_return_sequences,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
                cross_attention_kwargs=cross_attention_kwargs,
                iteration=iteration,
            )

        return Result.from_result(result)


class Img2Img(AbstractPipeline):
    """Image-to-image pipeline with one-time weight loading."""

    __loaded = False

    def load(self, model_dir: str):
        if self.__loaded:
            return

        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
        ).to("cuda")
        self.__patch()

        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        self.__patch()

        self.__loaded = True

    def __patch(self):
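        # xformers memory-efficient attention lowers VRAM usage during inference.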
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        strength: float,
        guidance_scale: float,
        steps: int,
        width: int,
        height: int,
    ):
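        # Download the source image and resize it to the target output size.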
        image = download_image(imageUrl).resize((width, height))

        result = self.pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        )
        return Result.from_result(result)
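

# Minimal usage sketch (an illustration, not part of the module's API surface):
# the model id and prompts below are placeholders, and the pipelines are
# assumed to run on a CUDA device as loaded above.
#
#   text2img = Text2Img()
#   text2img.load("<model-dir-or-hub-id>")
#   result = text2img.process(
#       Text2Img.Params(
#           prompt=["a watercolor fox"],
#           modified_prompt=["a watercolor fox, detailed lighting"],
#       ),
#       width=512,
#       height=512,
#   )
#
#   # Img2Img can share the already-loaded components instead of reloading:
#   img2img = Img2Img()
#   img2img.create(text2img)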