from typing import List, Optional

import torch

from internals.data.dataAccessor import update_db
from internals.data.task import Task, TaskType
from internals.pipelines.commons import Img2Img, Text2Img
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.img_classifier import ImageClassifier
from internals.pipelines.img_to_text import Image2Text
from internals.pipelines.inpainter import InPainter
from internals.pipelines.pose_detector import PoseDetector
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.safety_checker import SafetyChecker
from internals.util.args import apply_style_args
from internals.util.avatar import Avatar
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda, clear_cuda_and_gc
from internals.util.commons import pickPoses, upload_image, upload_images
from internals.util.config import (
    num_return_sequences,
    set_configs_from_task,
    set_root_dir,
)
from internals.util.failure_hander import FailureHandler
from internals.util.lora_style import LoraStyle
from internals.util.slack import Slack

# Process-wide PyTorch backend switches, set once at import time.
# cudnn.benchmark lets cuDNN auto-tune its convolution algorithms;
# allow_tf32 permits TF32 math for CUDA matmuls (see the PyTorch backends docs).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not referenced anywhere in the visible part of this file —
# presumably read by other modules; confirm before removing.
auto_mode = False

# Module-level singletons shared by every handler below. They are only
# constructed here; the load()/create() calls happen later in model_fn().
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
pose_detector = PoseDetector()
inpainter = InPainter()
img2text = Image2Text()
img_classifier = ImageClassifier()
controlnet = ControlNet()
lora_style = LoraStyle()
text2img_pipe = Text2Img()
img2img_pipe = Img2Img()
safety_checker = SafetyChecker()
slack = Slack()
avatar = Avatar()


def get_patched_prompt(task: Task):
    """Build the styled prompt lists for ``task``.

    Returns a ``(prompt, ori_prompt)`` tuple of lists:

    * ``prompt`` is the output of ``prompt_modifier.modify`` when the task
      requests prompt engineering, otherwise the task prompt repeated
      ``num_return_sequences`` times.
    * ``ori_prompt`` is always the task prompt repeated
      ``num_return_sequences`` times.

    Every entry of both lists is passed through ``avatar.add_code_names`` and
    ``lora_style.prepend_style_to_prompt`` before being returned. The modified
    prompts are also printed.
    """

    def decorate(entries: List[str], prefix: Optional[str] = None):
        # Rewrites ``entries`` in place, one entry at a time.
        for index, text in enumerate(entries):
            text = avatar.add_code_names(text)
            text = lora_style.prepend_style_to_prompt(text, task.get_style())
            if prefix:
                text = prefix + " " + text
            entries[index] = text

    base_prompt = task.get_prompt()
    if task.is_prompt_engineering():
        patched = prompt_modifier.modify(base_prompt)
    else:
        patched = [base_prompt] * num_return_sequences

    untouched = [task.get_prompt()] * num_return_sequences

    # No extra prefix is applied on this path, so the default (None) is used.
    decorate(untouched)
    decorate(patched)

    print({"prompts": patched})

    return (patched, untouched)


def get_patched_prompt_text2img(task: Task) -> Text2Img.Params:
    """Assemble the ``Text2Img.Params`` prompt bundle for ``task``.

    When the task carries both a left and a right prompt, three parallel lists
    (``prompt``, ``prompt_left``, ``prompt_right``) are produced. Each entry is
    the styled base/left/right prompt followed by whatever text is left of the
    expanded prompt once the original task prompt is removed from it.

    Otherwise ``prompt`` is the styled task prompt repeated
    ``num_return_sequences`` times and ``modified_prompt`` holds the styled
    expanded prompts.

    The resulting params object is printed and returned.
    """

    def stylize(text: str, prefix: str = ""):
        with_names = avatar.add_code_names(text)
        styled = lora_style.prepend_style_to_prompt(with_names, task.get_style())
        return prefix + styled

    def expanded_prompts():
        # Prompt-engineered variants if requested, else the prompt repeated.
        if task.is_prompt_engineering():
            return prompt_modifier.modify(task.get_prompt())
        return [task.get_prompt()] * num_return_sequences

    if not (task.get_prompt_left() and task.get_prompt_right()):
        variants = [stylize(variant) for variant in expanded_prompts()]

        params = Text2Img.Params(
            prompt=[stylize(task.get_prompt())] * num_return_sequences,
            modified_prompt=variants,
        )
    else:
        # The "2characters, " prefix is currently disabled; an empty prefix is used.
        prefix = ""

        centre, left, right = [], [], []
        for variant in expanded_prompts():
            # Text remaining in the variant after removing the original prompt.
            suffix = variant.replace(task.get_prompt(), "")
            centre.append(stylize(task.get_prompt(), prefix) + suffix)
            left.append(stylize(task.get_prompt_left(), prefix) + suffix)
            right.append(stylize(task.get_prompt_right(), prefix) + suffix)

        params = Text2Img.Params(
            prompt=centre,
            prompt_left=left,
            prompt_right=right,
        )

    print(params)

    return params


def get_patched_prompt_tile_upscale(task: Task):
    """Build the single prompt string used by the tile upscaler.

    The task prompt is used when it is truthy; otherwise a prompt is obtained
    from ``img2text.process`` on the task image. The prompt is then passed
    through ``avatar.add_code_names`` and ``lora_style.prepend_style_to_prompt``.
    When the task has no style, the result of ``img_classifier.classify`` is
    prepended (separated by a space). The stripped prompt is printed and
    returned.
    """
    user_prompt = task.get_prompt()
    text = user_prompt if user_prompt else img2text.process(task.get_imageUrl())

    text = lora_style.prepend_style_to_prompt(
        avatar.add_code_names(text), task.get_style()
    )

    # Only unstyled tasks get a classifier-derived prefix.
    prefix = (
        ""
        if task.get_style()
        else img_classifier.classify(
            task.get_imageUrl(), task.get_width(), task.get_height()
        )
    )
    text = (prefix + " " + text).strip()

    print({"prompt": text})

    return text


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Run the ControlNet canny pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompt, image URL, seed, steps, size,
            guidance scale, negative prompt, style and task id.

    Returns:
        A dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``.

    Any exception raised by generation or upload propagates to the caller
    after the cleanup calls have run.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()

    try:
        images, has_nsfw = controlnet.process_canny(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_cy_guidance_scale(),
            negative_prompt=[
                f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
            ]
            * num_return_sequences,
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_canny", task.get_taskId())
    finally:
        # Run both cleanups even when generation or upload fails, so a failed
        # task cannot skip them.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Run the ControlNet tile upscaler for ``task`` and upload the first image.

    Args:
        task: The request; supplies the image URL, seed, steps, size, resize
            dimension, guidance scale, negative prompt, style and task id.

    Returns:
        A dict with ``modified_prompts`` (a single prompt string here),
        ``generated_image_url`` and ``has_nsfw``.

    Any exception raised by generation or upload propagates to the caller
    after the cleanup calls have run.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_tile_upscaler()

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()

    try:
        # NOTE(review): unlike canny/pose, lora_patcher.kwargs() is not
        # forwarded here — confirm this is intentional.
        images, has_nsfw = controlnet.process_tile_upscaler(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            resize_dimension=task.get_resize_dimension(),
            negative_prompt=task.get_negative_prompt(),
            guidance_scale=task.get_ti_guidance_scale(),
        )

        # Only the first generated image is uploaded.
        generated_image_url = upload_image(images[0], output_key)
    finally:
        # Run both cleanups even when generation or upload fails, so a failed
        # task cannot skip them.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def scribble(task: Task):
    """Run the ControlNet scribble pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompt, image URL, seed, steps, size,
            negative prompt, style and task id.

    Returns:
        A dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``.

    Any exception raised by generation or upload propagates to the caller
    after the cleanup calls have run.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_scribble()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()

    try:
        images, has_nsfw = controlnet.process_scribble(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
    finally:
        # Run both cleanups even when generation or upload fails, so a failed
        # task cannot skip them.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def linearart(task: Task):
    """Run the ControlNet linearart pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompt, image URL, seed, steps, size,
            negative prompt, style and task id.

    Returns:
        A dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``.

    Any exception raised by generation or upload propagates to the caller
    after the cleanup calls have run.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_linearart()

    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()

    try:
        images, has_nsfw = controlnet.process_linearart(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )

        generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
    finally:
        # Run both cleanups even when generation or upload fails, so a failed
        # task cannot skip them.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Run the ControlNet pose pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompt, image URL, optional pose
            coordinates, seed, steps, size, guidance scale, negative prompt,
            style and task id.
        s3_outkey: Suffix passed to ``upload_images`` for the generated images.
        poses: Pre-computed pose images. When ``None``, a pose is derived from
            the task (from the client coordinates if present, otherwise
            detected from the task image) and repeated
            ``num_return_sequences`` times.

    Returns:
        A dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``.

    Any exception raised by pose detection, generation or upload propagates
    to the caller after the cleanup calls have run.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()

    try:
        if poses is None:
            if task.get_pose_coordinates():
                infered_pose = pose_detector.transform(
                    image=task.get_imageUrl(),
                    client_coordinates=task.get_pose_coordinates(),
                    width=task.get_width(),
                    height=task.get_height(),
                )
                poses = [infered_pose] * num_return_sequences
            else:
                poses = [
                    controlnet.detect_pose(task.get_imageUrl())
                ] * num_return_sequences

        images, has_nsfw = controlnet.process_pose(
            prompt=prompt,
            image=poses,
            seed=task.get_seed(),
            steps=task.get_steps(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_po_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        # The first pose image is uploaded alongside the generated images.
        pose_output_key = "crecoAI/{}_pose.png".format(task.get_taskId())
        upload_image(poses[0], pose_output_key)

        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
    finally:
        # Run both cleanups even when detection, generation or upload fails,
        # so a failed task cannot skip them.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Run the text-to-image pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompts, seed, steps, size, negative
            prompt, iteration, style and task id.

    Returns:
        A dict containing every attribute of the prompt params object plus
        ``generated_image_urls`` and ``has_nsfw``.

    Any exception raised by generation or upload propagates to the caller
    after the LoRA patcher cleanup has run.
    """
    params = get_patched_prompt_text2img(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()

    try:
        torch.manual_seed(task.get_seed())

        # NOTE(review): guidance_scale is fixed at 7.5 here, whereas the other
        # handlers read it from the task — confirm this is intentional.
        images, has_nsfw = text2img_pipe.process(
            params=params,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=task.get_negative_prompt(),
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # Run the patcher cleanup even when generation or upload fails, so a
        # failed task cannot skip it.
        lora_patcher.cleanup()

    return {
        **params.__dict__,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Run the image-to-image pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompt, image URL, seed, steps, size,
            strength, guidance scale, negative prompt, style and task id.

    Returns:
        A dict with ``modified_prompts``, ``generated_image_urls`` and
        ``has_nsfw``.

    Any exception raised by generation or upload propagates to the caller
    after the LoRA patcher cleanup has run.
    """
    prompt, _ = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()

    try:
        torch.manual_seed(task.get_seed())

        images, has_nsfw = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            strength=task.get_i2i_strength(),
            guidance_scale=task.get_i2i_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        # Run the patcher cleanup even when generation or upload fails, so a
        # failed task cannot skip it.
        lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Run the inpainting pipeline for ``task`` and upload the results.

    Args:
        task: The request; supplies the prompt, image URL, mask image URL,
            size, seed, negative prompt and task id.

    Returns:
        A dict with ``modified_prompts`` and ``generated_image_urls``. Unlike
        the other handlers in this file, no ``has_nsfw`` key is returned.

    Any exception raised by inpainting or upload propagates to the caller
    after ``clear_cuda()`` has run.
    """
    prompt, _ = get_patched_prompt(task)

    # get_patched_prompt() prints the same dict; kept so log output is unchanged.
    print({"prompts": prompt})

    try:
        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # This handler has no auto_clear_cuda_and_gc decorator, so call
        # clear_cuda() on the failure path as well.
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


def model_fn(model_dir):
    """Initialise every module-level pipeline object.

    Args:
        model_dir: Directory passed to the ``load``/``load_local`` calls of
            ``avatar``, ``lora_style``, ``controlnet`` and ``text2img_pipe``.

    Returns:
        None. All state is kept in the module-level singletons.
    """
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local(model_dir)

    # Helpers that load without needing model_dir.
    prompt_modifier.load()
    pose_detector.load()
    img2text.load()
    img_classifier.load()

    lora_style.load(model_dir)
    safety_checker.load()

    controlnet.load(model_dir)
    text2img_pipe.load(model_dir)
    # img2img and inpainter are created from the text2img pipe, so they must
    # come after text2img_pipe.load().
    img2img_pipe.create(text2img_pipe)
    inpainter.create(text2img_pipe)

    # The safety checker is applied to these three pipelines only; the
    # inpainter is not passed to it here.
    safety_checker.apply(text2img_pipe)
    safety_checker.apply(img2img_pipe)
    safety_checker.apply(controlnet)

    print("Logs: model loaded ....")
    return


@FailureHandler.clear
def predict_fn(data, pipe):
    """Route one request to the handler matching its task type.

    Args:
        data: Raw request payload; wrapped in a ``Task`` and also passed to
            ``apply_style_args``.
        pipe: Not used by this function.

    Returns:
        Whatever the selected handler returns, or ``None`` when any exception
        is raised during setup or handling (the error is printed, reported
        through ``slack.error_alert`` and ``controlnet.cleanup()`` is called).
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Per-request setup: environment config, style arguments, styles, avatars.
        set_configs_from_task(task)
        apply_style_args(data)
        lora_style.fetch_styles()
        avatar.fetch_from_network(task.get_model_id())

        kind = task.get_type()

        if kind == TaskType.TEXT_TO_IMAGE:
            # "character sheet" prompts are rendered through the pose pipeline
            # with picked poses instead of plain text-to-image.
            if "character sheet" in task.get_prompt().lower():
                return pose(task, s3_outkey="", poses=pickPoses())
            return text2img(task)
        if kind == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        if kind == TaskType.CANNY:
            return canny(task)
        if kind == TaskType.POSE:
            return pose(task)
        if kind == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        if kind == TaskType.INPAINT:
            return inpaint(task)
        if kind == TaskType.SCRIBBLE:
            return scribble(task)
        if kind == TaskType.LINEARART:
            return linearart(task)

        raise Exception("Invalid task type")
    except Exception as err:
        print(f"Error: {err}")
        slack.error_alert(task, err)
        controlnet.cleanup()
        return None