import albumentations as A
import base64
import cv2
import gradio as gr
import inspect
import io
import numpy as np
import os

from dataclasses import dataclass
from loguru import logger
from copy import deepcopy
from functools import wraps
from PIL import Image, ImageDraw
from typing import get_type_hints, Optional
from pydantic_core._pydantic_core import ValidationError
# from mixpanel import Mixpanel

from utils import is_not_supported_transform



# MIXPANEL_TOKEN = os.getenv("MIXPANEL_TOKEN")
# mp = Mixpanel(MIXPANEL_TOKEN)

HEADER = f"""
<div align="center">
    <p>
        <img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
        <span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
    </p>
    <p style="margin-top: -15px;">
        <a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a>
        &nbsp;
        <a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a>
    </p>
</div>
"""

DEFAULT_TRANSFORM = "Rotate"
NO_OPERATION_TRANFORM = "NoOp"

DEFAULT_IMAGE_PATH = "images/doctor.webp"
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH))
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0]
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1]
DEFAULT_BOXES = [
    [265, 121, 326, 177],  # Mask
    [192, 169, 401, 395],  # Coverall
]

mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints

BASE64_DEFAULT_MASKS = [
    {
        "label": "Coverall",
        # light green color
        "color": (144, 238, 144),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
    },
    {
        "label": "Mask",
        # light blue color
        "color": (173, 216, 230),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
    },
]

# Get all the transforms from the albumentations library
transforms_map = {
    name: cls
    for name, cls in vars(A).items()
    if (
        inspect.isclass(cls)
        and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
        and not is_not_supported_transform(cls)
        and not name.endswith("3D")
    )
}
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_map.pop("ReferenceBasedTransform", None)
transforms_map.pop("ToFloat", None)
transforms_map.pop("Normalize", None)
transforms_keys = list(sorted(transforms_map.keys()))


# Decode the masks
for mask in BASE64_DEFAULT_MASKS:
    mask["mask"] = np.array(
        Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
    )


@dataclass
class RequestParams:
    user_ip: str
    transform_name: Optional[str]

def track_event(event_name, user_id="unknown", properties=None):
    if properties is None:
        properties = {}
    #mp.track(user_id, event_name, properties)
    logger.info(f"Event tracked: {event_name} - {properties}")


def get_params(request: gr.Request) -> RequestParams:
    """Parse input request parameters."""
    ip = request.client.host
    transform_name = request.query_params.get("transform", None)
    params = RequestParams(user_ip=ip, transform_name=transform_name)
    track_event("app_opened", user_id=params.user_ip, properties={"transform_name": params.transform_name})
    return params


def run_with_retry(compose):
    @wraps(compose)
    def wrapper(*args, **kwargs):
        processors = deepcopy(compose.processors)
        for _ in range(4):
            try:
                result = compose(*args, **kwargs)
                break
            except NotImplementedError as e:
                print(f"Caught NotImplementedError: {e}")
                if "bbox" in str(e):
                    kwargs.pop("bboxes", None)
                    kwargs.pop("category_id", None)
                    compose.processors.pop("bboxes")
                if "keypoint" in str(e):
                    kwargs.pop("keypoints", None)
                    compose.processors.pop("keypoints")
                if "mask" in str(e):
                    kwargs.pop("mask", None)
            except (ValueError, ValidationError) as e:
                raise gr.Error(str(e))
            except Exception as e:
                compose.processors = processors
                raise e
        compose.processors = processors
        return result

    return wrapper


def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray:
    """Draw boxes with PIL."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for box in boxes:
        x_min, y_min, x_max, y_max = box
        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(pil_image)


def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Draw keypoints with PIL."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for keypoint in keypoints:
        x, y = keypoint
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
    return np.array(pil_image)


def get_rgb_mask(masks):
    """Get the RGB mask from the binary mask."""
    rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
    for data in masks:
        mask = data["mask"]
        rgb_mask[mask > 0] = np.array(data["color"])
    return rgb_mask


def draw_mask(image, mask):
    """Draw the mask on the image."""
    image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
    return image_with_mask


def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
    """Draw the image with a text. In the middle."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    # align in the centerm, and make bigger font
    text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
    length = draw.textlength(text)
    draw.text(
        (DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)


def get_formatted_signature(function_or_class, indentation=4):

    signature = inspect.signature(function_or_class)
    type_hints = get_type_hints(function_or_class)

    args = []
    for param in signature.parameters.values():
        if param.name == "p":
            str_param = "p=1.0,"
        elif param.default == inspect.Parameter.empty:
            if "height" in param.name or "width" in param.name:
                str_param = f"{param.name}=300,"
            else:
                str_param = f"{param.name}=,"
        else:
            if isinstance(param.default, str):
                str_param = f'{param.name}="{param.default}",'
            else:
                str_param = f"{param.name}={param.default},"

        annotation = type_hints.get(param.name, param.annotation)
        if isinstance(param.annotation, type):
            str_param += f"  # {param.annotation.__name__}"
        else:
            str_annotation = str(annotation).replace("typing.", "")
            str_param += f"  # {str_annotation}"
        str_param = "\n" + " " * indentation + str_param
        args.append(str_param)

    result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
    return result


def get_formatted_transform(transform_name):
    track_event("transform_selected", properties={"transform_name": transform_name})
    transform = transforms_map[transform_name]
    return f"A.{transform.__name__}{get_formatted_signature(transform)}"


def get_formatted_transform_docs(transform_name):
    transform = transforms_map[transform_name]
    return transform.__doc__.strip("\n")


def update_augmented_images(image, code):

    if "=," in code:
        raise gr.Error("You have to fill in parameters to apply transform! See 'Code' section!")

    try:
        augmentation = eval(code)
    except ValidationError as e:
        raise gr.Error(str(e))
    except Exception as e:
        logger.info(code)
        logger.error(e)
        raise e
        
    track_event("transform_applied", properties={"transform_name": augmentation.__class__.__name__, "code": code})

    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )
    compose = run_with_retry(compose)  # to prevent NotImplementedError

    keypoints = DEFAULT_KEYPOINTS
    bboxes = DEFAULT_BOXES
    mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
    augmented = compose(
        image=image,
        mask=mask,
        keypoints=keypoints,
        bboxes=bboxes,
        category_id=range(len(bboxes)),
    )
    image = augmented["image"]
    mask = augmented.get("mask", None)
    bboxes = augmented.get("bboxes", None)
    keypoints = augmented.get("keypoints", None)

    # Draw the augmented images (or replace by placeholder if not implemented)
    if mask is not None:
        image_with_mask = draw_mask(image.copy(), mask)
    else:
        image_with_mask = draw_not_implemented_image(image.copy(), "mask")

    if bboxes is not None:
        image_with_bboxes = draw_boxes(image.copy(), bboxes)
    else:
        image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")

    if keypoints is not None:
        image_with_keypoints = draw_keypoints(image.copy(), keypoints)
    else:
        image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]


def update_image_info(image):
    h, w = image.shape[:2]
    dtype = image.dtype
    max_, min_ = image.max(), image.min()
    return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"


def update_code_and_docs(select):
    code = get_formatted_transform(select)
    docs = get_formatted_transform_docs(select)
    return code, docs

def update_code_and_docs_on_start(url_params: gr.Request):
    params = get_params(url_params)
    if params.transform_name is not None and params.transform_name not in transforms_map:
        gr.Warning(f"Sorry, `{params.transform_name}` transform is not supported at the moment :(")
        transform_name = NO_OPERATION_TRANFORM
    elif params.transform_name in transforms_map:
        transform_name = params.transform_name
    else:
        transform_name = DEFAULT_TRANSFORM
    return gr.update(value=transform_name)

with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                # gr.Markdown(
                #     ("&nbsp;" * 4) + \
                #     "If a component is loading on start, please, try to refresh the page a few times. [Working on fix...]"
                # )
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="value",
                    interactive=True,
                )
                with gr.Accordion("Documentation (click to expand)", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(DEFAULT_TRANSFORM),
                        show_label=False,
                        interactive=False,
                    )
                code = gr.Code(
                    label="Code",
                    language="python",
                    value=get_formatted_transform(DEFAULT_TRANSFORM),
                    interactive=True,
                    lines=5,
                )
            info = gr.TextArea(
                value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                show_label=False,
                lines=1,
                max_lines=1,
            )
            button = gr.Button("Apply!")
        image = gr.Image(
            value=DEFAULT_IMAGE_PATH,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        augmented_image = gr.Gallery(
            value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
            rows=1,
            columns=3,
            show_label=False,
        )
    select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
    button.click(
        fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
    )
    demo.load(
        update_code_and_docs_on_start, inputs=None, outputs=[select], queue=False
    )

if __name__ == "__main__":
    demo.launch()