import dataclasses
import os

import hydra
import numpy as np
import torch
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from omegaconf import OmegaConf
from safetensors.torch import load_model
from scipy.spatial.transform import Rotation

from point_sam import build_point_sam
import argparse

# Flask application; static files are looked up in the "static" folder.
app = Flask(__name__, static_folder="static")
# Apply flask_cors to the whole app with its default settings.
CORS(app)

# Number of slots available for cached per-upload state.
MAX_POINT_ID = 100
# Index of the slot the next upload is written to.
point_info_id = 0
# One entry per slot; a slot holds ``None`` until an upload fills it.
point_info_list = [None] * MAX_POINT_ID

@dataclasses.dataclass
class AuxInputs:
    """Auxiliary tensors handed to the mask decoder through ``aux_inputs``."""

    # In this file, filled from ``patches["coords"]``.
    coords: torch.Tensor
    # In this file, filled from ``patches["feats"]``.
    features: torch.Tensor
    # In this file, filled from ``patches["centers"]``.
    centers: torch.Tensor
    # Never assigned in this file; presumably interpolation data produced or
    # consumed by the decoder -- TODO confirm against the model code.
    interp_index: torch.Tensor = None
    interp_weight: torch.Tensor = None

def repeat_interleave(x: torch.Tensor, repeats: int, dim: int):
    """Repeat every slice of ``x`` along ``dim`` ``repeats`` times in a row.

    For an integer ``repeats`` the result matches
    ``torch.repeat_interleave(x, repeats, dim=dim)``.

    Args:
        x: Tensor to repeat.
        repeats: Number of consecutive copies of each slice.
        dim: Dimension to repeat along; negative values count from the end.

    Returns:
        ``x`` itself when ``repeats == 1``; otherwise a tensor whose size
        along ``dim`` is ``x.shape[dim] * repeats``.

    Raises:
        IndexError: If ``dim`` is outside ``[-x.dim(), x.dim())``.
    """
    if repeats == 1:
        return x

    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"dim {dim} is out of range for a tensor with {ndim} dimension(s)"
        )
    # ``dim + 1`` below is only meaningful for a non-negative axis index.
    if dim < 0:
        dim += ndim

    # Add a new axis right after ``dim``, broadcast it to ``repeats`` and fold
    # it back into ``dim`` so the copies of each slice end up adjacent.
    expanded_shape = list(x.shape)
    expanded_shape.insert(dim + 1, repeats)
    return x.unsqueeze(dim + 1).expand(expanded_shape).flatten(dim, dim + 1)


class PointCloudProcessor:
    """Normalize a colored point cloud and return it as arrays or tensors.

    The first call fits ``center`` (mean of the points) and ``scale`` (largest
    distance of any point from that center). Later calls and ``normalize``
    reuse the fitted values.

    Args:
        device: Device for the returned tensors when ``return_tensors="pt"``.
        batch: If true, a leading dimension of size 1 is added to the outputs.
        return_tensors: ``"pt"`` for torch tensors, ``"np"`` for numpy arrays.
    """

    def __init__(self, device="cuda", batch=True, return_tensors="pt"):
        self.device = device
        self.batch = batch
        self.return_tensors = return_tensors

        # Fitted lazily by the first ``__call__``.
        self.center = None
        self.scale = None

    def __call__(self, xyz: np.ndarray, rgb: np.ndarray):
        """Normalize coordinates and colors.

        Args:
            xyz: Point coordinates, one row per point.
            rgb: Point colors on a 0-255 scale.

        Returns:
            ``(coords, feats)`` as float32 arrays or tensors. Coordinates are
            centered and divided by ``scale``; colors are mapped to [-1, 1].

        Raises:
            ValueError: If the processor has to be fitted on an empty cloud,
                or if ``return_tensors`` is neither ``"np"`` nor ``"pt"``.
        """
        # NOTE: the input is used in its given orientation; no axis
        # conversion is applied here.
        xyz = np.asarray(xyz)
        rgb = np.asarray(rgb)

        if self.center is None or self.scale is None:
            if xyz.shape[0] == 0:
                raise ValueError("cannot fit normalization on an empty point cloud")
            self.center = xyz.mean(0)
            radius = np.max(np.linalg.norm(xyz - self.center, axis=-1))
            # A single point (or all-identical points) has zero radius;
            # dividing by it would turn every coordinate into NaN.
            self.scale = radius if radius > 0 else 1.0

        xyz = (xyz - self.center) / self.scale
        rgb = ((rgb / 255.0) - 0.5) * 2

        if self.return_tensors == "np":
            coords = np.float32(xyz)
            feats = np.float32(rgb)
            if self.batch:
                coords = np.expand_dims(coords, 0)
                feats = np.expand_dims(feats, 0)
        elif self.return_tensors == "pt":
            coords = torch.tensor(xyz, dtype=torch.float32, device=self.device)
            feats = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            if self.batch:
                coords = coords.unsqueeze(0)
                feats = feats.unsqueeze(0)
        else:
            raise ValueError(self.return_tensors)

        return coords, feats

    def normalize(self, xyz):
        """Apply the fitted center and scale to ``xyz``.

        ``center`` and ``scale`` must already be set, either by a previous
        ``__call__`` or by assigning the attributes directly.
        """
        return (xyz - self.center) / self.scale


class PointCloudSAMPredictor:
    """Stateful wrapper around the Point-SAM model for prompt-based masks.

    Typical use: ``set_pointcloud`` encodes a cloud once, ``set_prompts``
    stores the current prompts, and ``predict_mask`` decodes candidate masks
    for them. The logits of the previous prediction are kept in
    ``prompt_mask`` and the row chosen by ``candidate_index`` is fed back into
    the next prediction.
    """

    input_xyz: np.ndarray
    input_rgb: np.ndarray
    prompt_coords: list[tuple[float, float, float]]
    prompt_labels: list[int]

    coords: torch.Tensor
    feats: torch.Tensor

    pc_embedding: torch.Tensor
    patches: dict[str, torch.Tensor]
    prompt_mask: torch.Tensor

    def __init__(self):
        """Load the model checkpoint and reset all cached state."""
        print("Created model")
        model = build_point_sam("./model-2.safetensors")
        model.pc_encoder.patch_embed.grouper.num_groups = 1024
        model.pc_encoder.patch_embed.grouper.group_size = 128
        # Remember where the model lives so every tensor created later is
        # placed on the same device instead of assuming CUDA.
        if torch.cuda.is_available():
            model = model.cuda()
            self.device = "cuda"
        else:
            self.device = "cpu"
        model.eval()

        self.model = model

        self.input_rgb = None
        self.input_xyz = None

        self.input_processor = None
        self.coords = None
        self.feats = None

        self.pc_embedding = None
        self.patches = None

        self.prompt_coords = None
        self.prompt_labels = None
        self.prompt_mask = None
        self.candidate_index = 0

    @torch.no_grad()
    def set_pointcloud(self, xyz, rgb):
        """Normalize a point cloud, encode it and cache the encoder outputs.

        A fresh ``PointCloudProcessor`` is created, so the normalization is
        re-fitted to this cloud. Any previous ``prompt_mask`` is discarded.
        """
        self.input_xyz = xyz
        self.input_rgb = rgb

        self.input_processor = PointCloudProcessor(device=self.device)
        coords, feats = self.input_processor(xyz, rgb)
        self.coords = coords
        self.feats = feats

        pc_embedding, patches = self.model.pc_encoder(self.coords, self.feats)
        self.pc_embedding = pc_embedding
        self.patches = patches
        self.prompt_mask = None

    def set_prompts(self, prompt_coords, prompt_labels):
        """Store the prompt coordinates and labels used by ``predict_mask``."""
        self.prompt_coords = prompt_coords
        self.prompt_labels = prompt_labels

    def _decode_masks(self, prompt_coords, prompt_labels, prompt_masks, multimask_output):
        """Run prompt encoders and mask decoder on the cached embedding.

        Returns ``(logits, iou_preds)`` exactly as produced by the decoder;
        per the original note these are
        [B * M, num_outputs, num_points] and [B * M, num_outputs].
        """
        patches = self.patches
        centers = patches["centers"]
        knn_idx = patches["knn_idx"]
        coords = patches["coords"]
        feats = patches["feats"]
        aux_inputs = AuxInputs(coords=coords, features=feats, centers=centers)

        pc_pe = self.model.point_encoder.pe_layer(centers)
        sparse_embeddings = self.model.point_encoder(prompt_coords, prompt_labels)
        dense_embeddings = self.model.mask_encoder(prompt_masks, coords, centers, knn_idx)
        # Match the leading dimension of the dense embeddings to the sparse ones.
        dense_embeddings = repeat_interleave(
            dense_embeddings, sparse_embeddings.shape[0] // dense_embeddings.shape[0], 0
        )

        logits, iou_preds = self.model.mask_decoder(
            self.pc_embedding,
            pc_pe,
            sparse_embeddings,
            dense_embeddings,
            aux_inputs=aux_inputs,
            multimask_output=multimask_output,
        )
        return logits, iou_preds

    @torch.no_grad()
    def predict_mask(self):
        """Decode masks for the stored prompts.

        Prompt coordinates are normalized with the current input processor.
        Multiple candidate masks are requested only when there is exactly one
        prompt. Candidates are sorted by predicted score (best first), their
        logits are stored in ``prompt_mask`` and ``candidate_index`` is reset
        to 0.

        Returns:
            Boolean numpy array ``logits > 0`` with the candidates in sorted
            order.
        """
        normalized_prompt_coords = self.input_processor.normalize(
            np.array(self.prompt_coords)
        )
        prompt_coords = torch.tensor(
            normalized_prompt_coords, dtype=torch.float32, device=self.device
        ).reshape(1, -1, 3)
        prompt_labels = torch.tensor(
            self.prompt_labels, dtype=torch.bool, device=self.device
        ).reshape(1, -1)

        multimask_output = prompt_coords.shape[1] == 1

        # Feed the selected candidate of the previous prediction back in.
        previous_mask = None
        if self.prompt_mask is not None:
            previous_mask = self.prompt_mask[self.candidate_index].unsqueeze(0)

        logits, scores = self._decode_masks(
            prompt_coords, prompt_labels, previous_mask, multimask_output
        )
        logits = logits.squeeze(0)
        scores = scores.squeeze(0)

        # Sort according to scores
        _, order = scores.sort(descending=True)
        logits = logits[order]

        self.prompt_mask = logits  # [num_outputs, num_points]
        self.candidate_index = 0

        return (logits > 0).cpu().numpy()

    def set_candidate(self, index):
        """Choose which row of ``prompt_mask`` the next prediction reuses."""
        self.candidate_index = index


# Single predictor shared by every route; constructing it builds the model
# when this module is loaded.
predictor = PointCloudSAMPredictor()


@app.route("/")
def index():
    """Serve ``index.html`` from the app's static folder."""
    entry_page = "index.html"
    return app.send_static_file(entry_page)

@app.route("/assets/<path:path>")
def assets_route(path):
    """Serve a file from ``assets/`` in the static folder.

    The requested path is printed before the file is sent.
    """
    print(path)
    asset_location = "assets/{}".format(path)
    return app.send_static_file(asset_location)


@app.route("/hello_world", methods=["GET"])
def hello_world():
    """Return a fixed greeting string."""
    greeting = "Hello, World!"
    return greeting


@app.route("/set_pointcloud", methods=["POST"])
def set_pointcloud():
    """Encode an uploaded point cloud and cache the result in a slot.

    Expects a JSON body with ``"points"`` and ``"colors"``; both are reshaped
    to rows of three values. The encoder outputs and the fitted normalization
    are stored (on CPU) in ``point_info_list``.

    Returns:
        JSON ``{"user_id": <slot index>}``. Slots are reused in round-robin
        order once all ``MAX_POINT_ID`` slots have been handed out, so the
        oldest cached upload is overwritten instead of the request failing
        with an ``IndexError``.
    """
    global point_info_id

    request_data = request.get_json()

    xyz = np.array(request_data["points"]).reshape(-1, 3)
    rgb = np.array(list(request_data["colors"])).reshape(-1, 3)
    predictor.set_pointcloud(xyz, rgb)

    # NOTE(review): "coords"/"feats" hold the predictor's processed input
    # tensors, not ``predictor.patches["coords"]``/``["feats"]`` -- confirm
    # this is intended.
    patches = {
        "centers": predictor.patches["centers"].cpu(),
        "knn_idx": predictor.patches["knn_idx"].cpu(),
        "coords": predictor.coords.cpu(),
        "feats": predictor.feats.cpu(),
    }

    # Wrap around so the counter never indexes past the pre-sized slot list.
    slot = point_info_id % MAX_POINT_ID
    point_info_list[slot] = {
        "pc_embedding": predictor.pc_embedding.cpu(),
        "patches": patches,
        "center": predictor.input_processor.center,
        "scale": predictor.input_processor.scale,
        "prompt_mask": None,
    }
    point_info_id = (slot + 1) % MAX_POINT_ID

    return jsonify({"user_id": slot})
 

@app.route("/set_candidate", methods=["POST"])
def set_candidate():
    """Set the candidate index on the shared predictor.

    Reads ``"index"`` from the JSON body and replies with ``"success"``.
    """
    payload = request.get_json()
    predictor.set_candidate(payload["index"])
    return "success"


def visualize_pcd_with_prompts(xyz, rgb, prompt_coords, prompt_labels):
    """Build a trimesh scene of the point cloud plus one sphere per prompt.

    A sphere is colored ``[255, 0, 0]`` when its label is truthy and
    ``[0, 255, 0]`` otherwise.
    """
    import trimesh

    cloud = trimesh.PointCloud(xyz, rgb)

    markers = []
    for idx, position in enumerate(prompt_coords):
        marker = trimesh.creation.icosphere()
        marker.apply_scale(0.02)
        marker.apply_translation(position)
        if prompt_labels[idx]:
            marker.visual.vertex_colors = [255, 0, 0]
        else:
            marker.visual.vertex_colors = [0, 255, 0]
        markers.append(marker)

    return trimesh.Scene([cloud, *markers])


@app.route("/set_prompts", methods=["POST"])
def set_prompts():
    """Predict masks for the prompts of one cached upload.

    Expects a JSON body with ``"prompt_coords"``, ``"prompt_labels"`` and the
    ``"user_id"`` returned by ``/set_pointcloud``. The cached state of that
    slot is loaded into the shared predictor before predicting, and the new
    mask logits are written back to the slot afterwards.

    Returns:
        JSON ``{"mask": ...}`` with the predicted boolean masks, or an empty
        list when no prompts were sent. An unknown ``user_id`` yields a 400
        response with an ``"error"`` message.

    NOTE(review): ``candidate_index`` lives on the shared predictor and is not
    stored per slot -- confirm this is acceptable with several clients.
    """
    request_data = request.get_json()
    print(request_data.keys())

    # [n_prompts, 3]
    prompt_coords = request_data["prompt_coords"]
    # [n_prompts]. 0 for negative, 1 for positive
    prompt_labels = request_data["prompt_labels"]
    user_id = request_data["user_id"]
    print(user_id)

    # Reject ids that do not refer to a filled slot instead of crashing.
    if (
        not isinstance(user_id, int)
        or not 0 <= user_id < len(point_info_list)
        or point_info_list[user_id] is None
    ):
        return jsonify({"error": f"unknown user_id: {user_id!r}"}), 400
    point_info = point_info_list[user_id]

    # Same availability check the predictor uses when placing the model.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load this slot's cached state into the shared predictor.
    cached_patches = point_info["patches"]
    predictor.pc_embedding = point_info["pc_embedding"].to(device)
    predictor.patches = {
        key: cached_patches[key].to(device)
        for key in ("centers", "knn_idx", "coords", "feats")
    }
    predictor.input_processor.center = point_info["center"]
    predictor.input_processor.scale = point_info["scale"]
    cached_mask = point_info["prompt_mask"]
    predictor.prompt_mask = cached_mask.to(device) if cached_mask is not None else None

    if len(prompt_coords) == 0:
        # Clearing all prompts resets the mask prior. The cached copy must be
        # cleared too, otherwise it is reloaded on the next request.
        predictor.prompt_mask = None
        point_info["prompt_mask"] = None
        return jsonify({"mask": []})

    predictor.set_prompts(prompt_coords, prompt_labels)
    pred_mask = predictor.predict_mask()
    point_info["prompt_mask"] = predictor.prompt_mask.cpu()

    return jsonify({"mask": pred_mask.tolist()})


if __name__ == "__main__":
    # Command-line entry point: start the Flask development server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    # Debug mode enables the interactive debugger, which allows arbitrary
    # code execution from the browser. With the default host the server
    # listens on all interfaces, so debug mode must be requested explicitly.
    parser.add_argument(
        "--debug",
        action="store_true",
        help="run Flask in debug mode (never use on a reachable network)",
    )
    args = parser.parse_args()
    app.run(host=args.host, port=args.port, debug=args.debug)