import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from sklearn.decomposition import PCA
from torchvision import transforms


def display_attention_maps(
        attention_maps,
        is_cross,
        num_heads,
        tokenizer,
        prompts,
        dir_name,
        step,
        layer,
        resolution,
        is_query=False,
        is_key=False,
        points=None,
        image_path=None,
):
    # Split the flattened batch into per-sample, per-head maps. The batch is
    # assumed to stack unconditional maps first, then conditional ones
    # (classifier-free guidance), hence the // 2 below.
    attention_maps = attention_maps.reshape(-1, num_heads, attention_maps.size(-2), attention_maps.size(-1))
    num_samples = len(attention_maps) // 2

    attention_type = 'cross' if is_cross else 'self'
    # Apply the query/key suffix once, outside the loop; appending it inside
    # the loop would grow the directory name on every iteration.
    if is_query:
        attention_type = f'{attention_type}_queries'
    elif is_key:
        attention_type = f'{attention_type}_keys'

    for i, attention_map in enumerate(attention_maps):
        cond = 'uncond' if i < num_samples else 'cond'
        i = i % num_samples
        cur_dir_name = f'{dir_name}/{resolution}/{attention_type}/{layer}/{cond}/{i}'
        os.makedirs(cur_dir_name, exist_ok=True)

        if is_cross and not is_query:
            fig = show_cross_attention(attention_map, tokenizer, prompts[i % num_samples])
        else:
            fig = show_self_attention(attention_map)
            if points is not None:
                # Additionally visualize the self-attention of each query point.
                point_dir_name = f'{cur_dir_name}/points'
                os.makedirs(point_dir_name, exist_ok=True)
                for j, point in enumerate(points):
                    specific_point_dir_name = f'{point_dir_name}/{j}'
                    os.makedirs(specific_point_dir_name, exist_ok=True)
                    point_path = f'{specific_point_dir_name}/{step}.png'
                    point_fig = show_individual_self_attention(attention_map, point, image_path=image_path)
                    point_fig.save(point_path)
                    point_fig.close()

        fig.save(f'{cur_dir_name}/{step}.png')
        fig.close()


def text_under_image(image: np.ndarray, text: str, text_color: tuple[int, int, int] = (0, 0, 0)):
    # Extend the canvas below the image and center each caption line under it.
    h, w, c = image.shape
    offset = int(h * .2)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_size = cv2.getTextSize(text, font, 1, 2)[0]
    lines = text.splitlines()
    img = np.ones((h + offset + (text_size[1] + 2) * len(lines) - 2, w, c), dtype=np.uint8) * 255
    img[:h, :w] = image

    for i, line in enumerate(lines):
        text_size = cv2.getTextSize(line, font, 1, 2)[0]
        text_x, text_y = (w - text_size[0]) // 2, h + offset + i * (text_size[1] + 2)
        cv2.putText(img, line, (text_x, text_y), font, 1, text_color, 2)

    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
    # Pad with blank tiles so the number of images divides evenly into rows.
    if type(images) is list:
        num_empty = -len(images) % num_rows
    elif images.ndim == 4:
        num_empty = -images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]

    return Image.fromarray(image_)
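

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how text_under_image and view_images compose a
# captioned grid. The helper name `_demo_grid` and the random tiles are
# illustrative assumptions standing in for real attention heatmaps.
def _demo_grid():
    rng = np.random.default_rng(0)
    # Four random 256x256 RGB tiles, each captioned below the image.
    tiles = [
        text_under_image(rng.integers(0, 256, (256, 256, 3), dtype=np.uint8), f'tile {i}')
        for i in range(4)
    ]
    # Arrange the captioned tiles into a 2x2 grid and return a PIL image.
    return view_images(np.stack(tiles, axis=0), num_rows=2)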


def show_cross_attention(attention_maps, tokenizer, prompt, k_norms=None, v_norms=None):
    # Average over heads, then fold the flattened spatial axis back to res x res.
    attention_maps = attention_maps.mean(dim=0)
    res = int(attention_maps.size(-2) ** 0.5)
    attention_maps = attention_maps.reshape(res, res, -1)
    tokens = tokenizer.encode(prompt)
    decoder = tokenizer.decode
    if k_norms is not None:
        k_norms = k_norms.round(decimals=1)
    if v_norms is not None:
        v_norms = v_norms.round(decimals=1)
    images = []
    # Also visualize a few positions past the end of the prompt (padding tokens).
    for i in range(len(tokens) + 5):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.detach().cpu().numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        token = tokens[i] if i < len(tokens) else tokens[-1]
        text = decoder(int(token))
        if k_norms is not None and v_norms is not None:
            text += f'\n{k_norms[i]}\n{v_norms[i]}'
        image = text_under_image(image, text)
        images.append(image)
    return view_images(np.stack(images, axis=0))


def show_queries_keys(queries, keys, colors, labels):  # [h ni d]
    num_queries = [query.size(1) for query in queries]
    num_keys = [key.size(1) for key in keys]
    h, _, d = queries[0].shape

    # Concatenate all queries and keys, flatten the head dimension into the
    # feature axis, and project everything to 2D with a shared PCA basis.
    data = torch.cat((*queries, *keys), dim=1)  # h n d
    data = data.permute(1, 0, 2)  # n h d
    data = data.reshape(-1, h * d).detach().cpu().numpy()
    pca = PCA(n_components=2)
    data = pca.fit_transform(data)  # n 2

    query_indices = np.array(num_queries).cumsum()
    total_num_queries = query_indices[-1]
    queries = np.split(data[:total_num_queries], query_indices[:-1])

    if len(num_keys) == 0:
        keys = [None, ] * len(labels)
    else:
        key_indices = np.array(num_keys).cumsum()
        keys = np.split(data[total_num_queries:], key_indices[:-1])

    fig, ax = plt.subplots()
    marker_size = plt.rcParams['lines.markersize'] ** 2
    query_size = int(1.25 * marker_size)
    key_size = int(2 * marker_size)
    for query, key, color, label in zip(queries, keys, colors, labels):
        print(f'# queries of {label}', query.shape[0])
        ax.scatter(query[:, 0], query[:, 1], s=query_size, color=color, marker='o', label=f'"{label}" queries')

        if key is None:
            continue

        print(f'# keys of {label}', key.shape[0])
        keys_label = f'"{label}" key'
        if key.shape[0] > 1:
            keys_label += 's'
        ax.scatter(key[:, 0], key[:, 1], s=key_size, color=color, marker='x', label=keys_label)

    ax.set_axis_off()
    return fig


def show_self_attention(attention_maps):  # h n m
    # Flatten heads into the feature axis and reduce each query's attention
    # distribution to 3 PCA components, rendered as an RGB image.
    attention_maps = attention_maps.transpose(0, 1).flatten(start_dim=1).detach().cpu().numpy()
    pca = PCA(n_components=3)
    pca_img = pca.fit_transform(attention_maps)  # n 3
    h = w = int(pca_img.shape[0] ** 0.5)
    pca_img = pca_img.reshape(h, w, 3)
    pca_img_min = pca_img.min(axis=(0, 1))
    pca_img_max = pca_img.max(axis=(0, 1))
    pca_img = (pca_img - pca_img_min) / (pca_img_max - pca_img_min)
    pca_img = Image.fromarray((pca_img * 255).astype(np.uint8))
    pca_img = transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST)(pca_img)
    return pca_img


def draw_box(pil_img, bboxes, colors=None, width=5):
    # Boxes are given as relative [x0, y0, x1, y1] coordinates in [0, 1].
    draw = ImageDraw.Draw(pil_img)
    w, h = pil_img.size
    colors = ['red'] * len(bboxes) if colors is None else colors
    for obj_bbox, color in zip(bboxes, colors):
        x_0, y_0, x_1, y_1 = obj_bbox
        draw.rectangle([int(x_0 * w), int(y_0 * h), int(x_1 * w), int(y_1 * h)], outline=color, width=width)
    return pil_img


def show_individual_self_attention(attn, point, image_path=None):
    resolution = int(attn.size(-1) ** 0.5)
    attn = attn.mean(dim=0).reshape(resolution, resolution, resolution, resolution)
    # point is (x, y) in relative coordinates; clamp so x == 1.0 or y == 1.0
    # does not index out of bounds.
    y = min(round(point[1] * resolution), resolution - 1)
    x = min(round(point[0] * resolution), resolution - 1)
    attn = attn[y, x]
    attn = (attn - attn.min()) / (attn.max() - attn.min())
    image = None if image_path is None else Image.open(image_path).convert('RGB')
    image = show_image_relevance(attn, image=image)
    return Image.fromarray(image)
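

# --- Usage sketch (not part of the original module) --------------------------
# A hedged example of inspecting a single self-attention map. The helper name
# `_demo_point_map` and the random attention tensor (8 heads over a 32x32 grid)
# are illustrative assumptions standing in for a real map from a diffusion UNet.
def _demo_point_map():
    attn = torch.rand(8, 32 * 32, 32 * 32)
    attn = attn / attn.sum(dim=-1, keepdim=True)  # normalize rows like a softmax output
    # Visualize where the patch at relative coordinates (0.5, 0.5) attends.
    return show_individual_self_attention(attn, (0.5, 0.5))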


def show_image_relevance(image_relevance, image: Image.Image = None, relevance_res=16):
    # Create a heatmap from the relevance mask and overlay it on the image.
    def show_cam_on_image(img, mask):
        img = img.resize((relevance_res ** 2, relevance_res ** 2))
        img = np.array(img)
        img = (img - img.min()) / (img.max() - img.min())
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1])
    if torch.cuda.is_available():
        image_relevance = image_relevance.cuda()  # float16 interpolation is not supported on CPU
    else:
        image_relevance = image_relevance.float()  # fall back to float32 on CPU
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevance_res ** 2, mode='bilinear')
    image_relevance = image_relevance.cpu()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image_relevance = image_relevance.reshape(relevance_res ** 2, relevance_res ** 2)

    vis = image_relevance.numpy() if image is None else show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    if vis.ndim == 2:
        # No base image: expand the grayscale relevance map to 3 channels.
        vis = np.repeat(vis[:, :, None], 3, axis=2)
    else:
        vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis
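

# --- Smoke test (not part of the original module) -----------------------------
# Runs the illustrative demos above when the file is executed directly. The
# output filenames are arbitrary choices for this sketch.
if __name__ == '__main__':
    _demo_grid().save('demo_grid.png')
    _demo_point_map().save('demo_point_map.png')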