import matplotlib.pyplot as plt
# from demo import Demo, read_input_image,read_input_image_test
from evaluation.viz import plot_example_single
from dataset.torch import unbatch_to_device
from typing import Optional, Tuple
import cv2
import torch
import numpy as np
import time
from logger import logger
from evaluation.run import resolve_checkpoint_path, pretrained_models
from models.maplocnet import MapLocNet
from models.voting import fuse_gps, argmax_xyr
# from data.image import resize_image, pad_image, rectify_image
from osm.raster import Canvas
from utils.wrappers import Camera
from utils.io import read_image
from utils.geo import BoundaryBox, Projection
from utils.exif import EXIF
import requests
from pathlib import Path
from dataset.image import resize_image, pad_image, rectify_image
# from maploc.demo import Demo, read_input_image
from dataset import UavMapDatasetModule
import torchvision.transforms as tvf
from sklearn.decomposition import PCA
from PIL import Image
import pyproj  # required by xyz_to_latlon below
# Query OpenStreetMap for this area
from osm.tiling import TileManager
from utils.viz_localization import (
    likelihood_overlay,
    plot_dense_rotations,
    add_circle_inset,
)
# Show the inputs to the model: image and raster map
from osm.viz import Colormap, plot_nodes
from utils.viz_2d import plot_images

from utils.viz_2d import features_to_RGB
import random
from geopy.distance import geodesic


def vis_image_feature(F):
    """Project a (C, H, W) feature map to an RGB image via PCA."""
    def normalize(x):
        return x / np.linalg.norm(x, axis=-1, keepdims=True)

    # Crop to a fixed window and flatten to (H*W, C).
    F = F[:, 0:180, 0:180]
    c, h, w = F.shape
    F = np.rollaxis(F, 0, 3)
    flatten = normalize(F.reshape(-1, c))
    flatten = np.nan_to_num(flatten, nan=0)

    # Reduce the channels to 3 principal components and rescale to [0, 1].
    pca = PCA(n_components=3)
    flatten = pca.fit_transform(flatten)
    flatten = (normalize(flatten) + 1) / 2

    F_rgb = flatten[: h * w].reshape((h, w, 3))
    return F_rgb
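

# Example (sketch): visualize the neural map produced by Demo.localize below,
# assuming `neural_map` is a (C, H, W) torch tensor:
#
#   F_rgb = vis_image_feature(neural_map.numpy())
#   plt.imshow(F_rgb); plt.axis('off'); plt.show()
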
def distance(lat1, lon1, lat2, lon2):
    """Geodesic distance between two WGS84 points, in meters."""
    return geodesic((lat1, lon1), (lat2, lon2)).meters

# # Example
# lat1, lon1 = 39.9, 116.4   # Beijing
# lat2, lon2 = 31.2, 121.5   # Shanghai
# dist_m = distance(lat1, lon1, lat2, lon2)
# print(dist_m)
def show_result(map_vis_image, pre_uv, pre_yaw):
    # Blend a gray mask over the map image. With weights (1, 0) the blend is
    # a no-op; it mainly yields a copy that is safe to draw on.
    gray_mask = np.zeros_like(map_vis_image)
    gray_mask.fill(128)
    image = cv2.addWeighted(map_vis_image, 1, gray_mask, 0, 0)

    # Draw the predicted pose as an arrow: pre_uv is the position in pixels,
    # pre_yaw the heading in degrees (0 = up, hence the -90 offset).
    u, v = pre_uv
    x1, y1 = int(u), int(v)
    angle = pre_yaw - 90
    length = 20
    x2 = int(x1 + length * np.cos(np.radians(angle)))
    y2 = int(y1 + length * np.sin(np.radians(angle)))
    cv2.arrowedLine(image, (x1, y1), (x2, y2), (0, 0, 0), 2, 5, 0, 0.3)
    # cv2.circle(image, (x1, y1), radius=2, color=(255, 0, 255), thickness=-1)
    return image
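

# Example (sketch): draw a predicted pose on the map visualization, assuming
# `uv` and `yaw` come from Demo.localize below:
#
#   vis = show_result(map_vis_image, (float(uv[0]), float(uv[1])), float(yaw))
#   plt.imshow(vis); plt.axis('off'); plt.show()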


def xyz_to_latlon(x, y, z):
    """Convert ECEF (geocentric) coordinates in meters to WGS84 lat/lon."""
    wgs84 = pyproj.CRS('EPSG:4326')
    ecef = pyproj.CRS('+proj=geocent +datum=WGS84 +units=m +no_defs')
    transformer = pyproj.Transformer.from_crs(ecef, wgs84)
    # EPSG:4326 axis order is (lat, lon).
    lat, lon, _ = transformer.transform(x, y, z)
    return lat, lon
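

# Example (sketch): a point on the WGS84 equator at the prime meridian,
# 6378137 m (the semi-major axis) from the Earth's center:
#
#   lat, lon = xyz_to_latlon(6378137.0, 0.0, 0.0)  # ~ (0.0, 0.0)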


class Demo:
    def __init__(
            self,
            experiment_or_path: Optional[str] = "OrienterNet_MGL",
            device=None,
            **kwargs
    ):
        if experiment_or_path in pretrained_models:
            experiment_or_path, _ = pretrained_models[experiment_or_path]
        path = resolve_checkpoint_path(experiment_or_path)
        ckpt = torch.load(path, map_location=(lambda storage, loc: storage))
        config = ckpt["hyper_parameters"]
        config.model.update(kwargs)
        config.model.image_encoder.backbone.pretrained = False

        model = MapLocNet(config.model).eval()
        state = {k[len("model."):]: v for k, v in ckpt["state_dict"].items()}
        model.load_state_dict(state, strict=True)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

        self.model = model
        self.config = config
        self.device = device

    def prepare_data(
            self,
            image: np.ndarray,
            camera: Camera,
            canvas: Canvas,
            roll_pitch: Optional[Tuple[float, float]] = None,
    ):
        # camera and roll_pitch are accepted for API compatibility but are
        # not used by this model variant.
        image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
        return {
            'map': torch.from_numpy(canvas.raster).long(),
            'image': image,
        }

    def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):

        data = self.prepare_data(image, camera, canvas, **kwargs)
        data_ = {k: v.to(self.device)[None] for k, v in data.items()}
        start = time.time()
        with torch.no_grad():
            pred = self.model(data_)

        end = time.time()
        xy_gps = canvas.bbox.center
        uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))

        lp_xyr = pred["log_probs"].squeeze(0)
        # tile_size = canvas.bbox.size.min() / 2
        # sigma = tile_size - 20  # 20 meters margin
        # lp_xyr = fuse_gps(
        #     lp_xyr,
        #     uv_gps.to(lp_xyr),
        #     self.config.model.pixel_per_meter,
        #     sigma=sigma,
        # )
        xyr = argmax_xyr(lp_xyr).cpu()

        prob = lp_xyr.exp().cpu()
        neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
        print('total time:', end - start)
        return xyr[:2], xyr[2], prob, neural_map, data["image"], data_, pred
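

# Example (sketch): turn a localization result back into lat/lon and measure
# the error against the prior with the helpers above; `proj`, `canvas`, `uv`,
# and `prior_latlon` are produced as in __main__ below:
#
#   pred_latlon = proj.unproject(canvas.to_xy(uv))
#   err_m = distance(pred_latlon[0], pred_latlon[1], *prior_latlon)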


def load_test_data(
        root: Path,
        city: str,
        index: int,
):
    uav_image_path = root / city / 'uav'
    map_path = root / city / 'map'
    map_vis = root / city / 'map_vis'
    info_path = root / city / 'info.csv'
    osm_path = root / city / '{}.osm'.format(city)

    info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)

    id, uav_name, map_name, \
        uav_long, uav_lat, \
        map_long, map_lat, \
        tile_size_meters, pixel_per_meter, \
        u, v, yaw, dis = info[index]
    print(info[index])
    uav_image_rgb = cv2.imread(str(uav_image_path / uav_name))
    uav_image_rgb = cv2.cvtColor(uav_image_rgb, cv2.COLOR_BGR2RGB)

    map_vis_image = cv2.imread(str(map_vis / uav_name))
    map_vis_image = cv2.cvtColor(map_vis_image, cv2.COLOR_BGR2RGB)

    map_raster = np.load(str(map_path / map_name))

    val_tfs = tvf.Compose([tvf.ToTensor(), tvf.Resize(256)])

    uav_image = val_tfs(uav_image_rgb)
    uav_path = str(uav_image_path / uav_name)
    return {
        'map': torch.from_numpy(np.ascontiguousarray(map_raster)).long().unsqueeze(0),
        'image': uav_image.unsqueeze(0),
        'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
        'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
        "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0),
    }, uav_image_rgb, map_vis_image, uav_path, [float(map_lat), float(map_long)]
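

# Example (sketch): load the first sample of a city split; the root layout
# root/<city>/{uav, map, map_vis, info.csv} matches the paths used above.
# 'datasets' and 'NewYork' are placeholder names:
#
#   batch, uav_rgb, map_vis, uav_path, map_latlon = load_test_data(
#       Path('datasets'), 'NewYork', 0)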


def crop_image(image, width, height):
    # Crop a centered width x height region.
    x = int((image.shape[1] - width) / 2)
    y = int((image.shape[0] - height) / 2)
    return image[y:y + height, x:x + width]


def crop_square(image):
    # Crop the largest centered square from the image.
    height, width = image.shape[:2]
    min_length = min(height, width)
    top = (height - min_length) // 2
    bottom = top + min_length
    left = (width - min_length) // 2
    right = left + min_length
    return image[top:bottom, left:right]
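

# Example (sketch): center-square then fixed-size crop, assuming `img` is an
# (H, W, 3) array:
#
#   img = crop_square(img)           # largest centered square
#   img = crop_image(img, 256, 256)  # centered 256x256 patch
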
def read_input_image_test(
        image_path,
        prior_latlon,
        tile_size_meters,
):
    image = read_image(image_path)
    image = cv2.resize(image, (256, 256))
    roll_pitch = None

    latlon = None
    if prior_latlon is not None:
        latlon = prior_latlon
        logger.info("Using prior latlon %s.", prior_latlon)
    else:
        # Fall back to the GPS tags in the image EXIF metadata.
        with open(image_path, "rb") as fid:
            exif = EXIF(fid, lambda: image.shape[:2])
        geo = exif.extract_geo()
        if geo:
            alt = geo.get("altitude", 0)  # read if available
            latlon = (geo["latitude"], geo["longitude"], alt)
            logger.info("Using prior location from EXIF.")
        else:
            logger.info("Could not find any prior location in the image EXIF metadata.")
    if latlon is None:
        raise ValueError("No prior location given and none found in the EXIF metadata.")
    latlon = np.array(latlon)

    proj = Projection(*latlon)
    center = proj.project(latlon)
    bbox = BoundaryBox(center, center) + float(tile_size_meters)
    camera = None  # this model variant does not need camera intrinsics
    return image, camera, roll_pitch, proj, bbox, latlon


if __name__ == '__main__':
    experiment_or_path = "weight/last-step-checkpointing.ckpt"
    # experiment_or_path = "experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
    image_path = 'images/00000.jpg'
    prior_latlon = (37.75704325989902, -122.435941445631)
    tile_size_meters = 128
    demo = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')
    image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
        image_path,
        prior_latlon=prior_latlon,
        tile_size_meters=tile_size_meters,  # try 64, 256, etc.
    )
    tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, tile_size=tile_size_meters)
    # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1,path=root/city/'{}.osm'.format(city), tile_size=1)
    canvas = tiler.query(bbox)
    uv, yaw, prob, neural_map, image_rectified, data_, pred = demo.localize(
        image, camera, canvas)
    prior_latlon_pred = proj.unproject(canvas.to_xy(uv))
    print("Predicted lat/lon:", prior_latlon_pred, "yaw (deg):", float(yaw))