import os
from typing import Dict, List

import cv2
import numpy as np
import streamlit as st
import torch
import wget
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from isegm.inference import clicker as ck
from isegm.inference import utils
from isegm.inference.predictors import BasePredictor, get_predictor

###################################
# Global scope objects.
###################################
# Base URL of the repository hosting the pretrained checkpoints.
URL_PREFIX = "https://huggingface.co/curt-park/interactive-segmentation/resolve/main"
# Method name shown in the sidebar -> checkpoint file name.
MODELS = {"RITM": "ritm_coco_lvis_h18_itermask.pth"}
# Canvas stroke colours marking a point as a positive / negative click.
POS_COLOR = "#3498DB"
NEG_COLOR = "#C70039"
# Size of the drawing canvas in pixels.
CANVAS_HEIGHT = 600
CANVAS_WIDTH = 600
# Offsets added to a canvas point's "left"/"top" before it is scaled to image
# coordinates; presumably empirical marker-position corrections (TODO confirm).
ERR_X = 5.5
ERR_Y = 1.0
# Run on the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
clicker = ck.Clicker()
# Filled in later by the script.
predictor = None
image = None


###################################
# Functions.
###################################
@st.cache(allow_output_mutation=True)
def load_model(model_path: str, device: torch.device) -> BasePredictor:
    """Load an interactive-segmentation checkpoint and wrap it in a predictor.

    The result is cached by Streamlit across reruns. ``allow_output_mutation``
    is set so the cached predictor object is reused as-is.

    Args:
        model_path: Path to the checkpoint file.
        device: Device the model is loaded onto.

    Returns:
        A predictor built with ``brs_mode="NoBRS"``.
    """
    net = utils.load_is_model(model_path, device, cpu_dist_maps=True)
    return get_predictor(net, device=device, brs_mode="NoBRS")


def feed_clicks(
    clicker: ck.Clicker,
    clicks: List[Dict[str, float]],
    image_width: int,
    image_height: int,
) -> None:
    """Convert canvas point objects into clicks and register them on ``clicker``.

    Each point's canvas position is rescaled from canvas size to image size.
    It is then clamped to the valid pixel range of the image.

    Args:
        clicker: Clicker that receives one ``ck.Click`` per canvas point.
        clicks: Canvas objects. Each is read for its ``"left"``, ``"top"`` and
            ``"stroke"`` entries; ``"stroke"`` is compared to a colour string,
            so the value type annotation is looser than the real data.
        image_width: Width of the input image in pixels.
        image_height: Height of the input image in pixels.
    """
    ratio_h, ratio_w = image_height / CANVAS_HEIGHT, image_width / CANVAS_WIDTH
    # Largest valid pixel coordinates. Clamping to the width/height itself
    # would allow a click one pixel past the image border.
    max_x, max_y = max(0, image_width - 1), max(0, image_height - 1)
    for obj in clicks:
        x = (obj["left"] + ERR_X) * ratio_w
        y = (obj["top"] + ERR_Y) * ratio_h
        x, y = min(max_x, max(0, x)), min(max_y, max(0, y))

        # The stroke colour encodes the click type chosen in the sidebar.
        is_positive = obj["stroke"] == POS_COLOR
        # Clicks take (row, column) coordinates.
        clicker.add_click(ck.Click(is_positive=is_positive, coords=(y, x)))


def predict(
    image: Image.Image, mask: torch.Tensor, threshold: float = 0.5
) -> np.ndarray:
    """Run the module-level predictor on ``image`` using the module-level clicks.

    Args:
        image: Input image; it is converted with ``np.array`` before prediction.
        mask: Passed to the predictor as ``prev_mask``.
        threshold: Values of the resized prediction above this become 1.0.

    Returns:
        A binary float array (1.0 / 0.0) with ``CANVAS_HEIGHT`` rows and
        ``CANVAS_WIDTH`` columns, matching the drawing canvas.
    """
    predictor.set_input_image(np.array(image))
    with st.spinner("Wait for prediction..."):
        pred = predictor.get_prediction(clicker, prev_mask=mask)
    # OpenCV's dsize is (width, height), not (rows, cols).
    pred = cv2.resize(
        pred,
        dsize=(CANVAS_WIDTH, CANVAS_HEIGHT),
        interpolation=cv2.INTER_CUBIC,
    )
    return np.where(pred > threshold, 1.0, 0)


###################################
# Sidebar GUI
###################################
# Widgets shown in the sidebar: method, threshold, click type and image upload.
sidebar = st.sidebar
model = sidebar.selectbox("Select a Method:", tuple(MODELS))
threshold = sidebar.slider("Threshold: ", 0.0, 1.0, 0.5)
marking_type = sidebar.radio("Click Type:", ("Positive", "Negative"))
image_path = sidebar.file_uploader("Background Image:", type=["png", "jpg", "jpeg"])
# Only replace the default (None) image once a file has been uploaded.
if image_path:
    image = Image.open(image_path).convert("RGB")

###################################
# Preparation
###################################
# Checkpoint file for the selected method, stored in the working directory.
checkpoint_file = MODELS[model]
# Download the checkpoint only if it is not present yet.
with st.spinner("Wait for downloading a model..."):
    if not os.path.exists(checkpoint_file):
        wget.download(f"{URL_PREFIX}/{checkpoint_file}")
# Build the predictor (cached by load_model across reruns).
with st.spinner("Wait for loading a model..."):
    predictor = load_model(checkpoint_file, device)

###################################
# GUI
###################################
# Drawing canvas where each point the user places becomes a click.
st.title("Canvas:")
# The stroke colour records which click type was active when a point was drawn.
stroke = NEG_COLOR
if marking_type == "Positive":
    stroke = POS_COLOR
canvas_result = st_canvas(
    key="canvas",
    width=CANVAS_WIDTH,
    height=CANVAS_HEIGHT,
    drawing_mode="point",
    point_display_radius=3,
    stroke_width=3,
    stroke_color=stroke,
    fill_color="rgba(255, 165, 0, 0.3)",  # constant semi-transparent fill
    background_color="#eee",
    background_image=image,
    update_streamlit=True,
)

###################################
# Prediction
###################################
# Check the user inputs and execute predictions.
st.title("Prediction:")
json_data = canvas_result.json_data
# Predict only when at least one point is drawn and an image has been uploaded.
if json_data and json_data["objects"] and image is not None:
    image_width, image_height = image.size
    feed_clicks(clicker, json_data["objects"], image_width, image_height)

    # Run prediction with an all-zero previous mask at image resolution.
    mask = torch.zeros((1, 1, image_height, image_width), device=device)
    pred = predict(image, mask, threshold)

    # Show the prediction result.
    st.image(pred, caption="")