import matplotlib.pyplot as plt
import numpy as np
import torch
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import Image

from utils.tools import format_results, point_prompt
from utils.tools_gradio import fast_process

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained model
sam_checkpoint = r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt"
model_type = "vit_t"
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)
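
# Note: mask_generator is constructed above but never used in this script. A
# minimal, commented-out sketch of prompt-free automatic segmentation with it
# (assuming `some_rgb_array` is an HxWx3 uint8 NumPy image):
#
#   anns = mask_generator.generate(some_rgb_array)
#   # -> a list of dicts with keys such as 'segmentation' ((H, W) bool mask),
#   #    'area', 'bbox', and 'predicted_iou'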

@torch.no_grad()
def segment_with_points(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    # The body uses point prompts (not boxes), so the function is named
    # accordingly. Prompt state is passed in via module-level globals.
    global global_points
    global global_point_label

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)

    image = image.resize((new_w, new_h))
    # Scale the prompt points to match the resized image.
    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    print("scaled points:", scaled_points)
    scaled_point_label = np.array(global_point_label)

    nd_image = np.array(image)
    print("input image shape:", nd_image.shape)  # e.g. (685, 1024, 3)
    predictor.set_image(nd_image)  # applies the model's internal image transform
    masks, scores, logits = predictor.predict(
        point_coords=scaled_points,
        point_labels=scaled_point_label,
        multimask_output=True,
    )
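    # With multimask_output=True, predict returns three candidate masks of
    # shape (3, H, W) together with a quality score for each.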

    results = format_results(masks, scores, logits, 0)
    print("number of candidate masks:", len(results))
    annotations, _ = point_prompt(
        results, scaled_points, scaled_point_label, new_h, new_w
    )
    annotations = np.array([annotations])
    # Visualize the merged mask.
    plt.imshow(annotations[0], cmap='viridis')  # use the 'viridis' colormap
    plt.colorbar()  # show a colorbar
    plt.savefig(r'F:\zht\code\2.png')
    plt.show()

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    global_points = []
    global_point_label = []
    return fig, image
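
# The original function name mentioned boxes, but the body above uses point
# prompts. Below is a minimal, hypothetical sketch (not part of the original
# script) of a true box prompt: SamPredictor.predict also accepts a length-4
# XYXY box via its `box` argument.
@torch.no_grad()
def segment_with_box(image, box_xyxy, input_size=1024):
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    image = image.resize((int(w * scale), int(h * scale)))

    predictor.set_image(np.array(image))
    scaled_box = np.array([int(v * scale) for v in box_xyxy])  # [x0, y0, x1, y1]
    masks, scores, _ = predictor.predict(box=scaled_box, multimask_output=False)
    return masks[0], image  # single (H, W) boolean mask and the resized image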

#################################################
if __name__ == "__main__":
    path = r"F:\zht\code\MobileSAM-master\app\assets\05.jpg"
    image1 = Image.open(path)
    print(image1.size)
    # Each prompt must be an (x, y) pair; the original 4-tuple
    # [[324, 740, 1448, 1192]] would crash SAM's point encoder, so it is
    # split into two foreground points here.
    global_points = [[324, 740], [1448, 1192]]
    global_point_label = [1, 1]
    segment_with_points(
        image1,
        input_size=1024,
        better_quality=False,
        withContours=True,
        use_retina=True,
        mask_random_color=True,
    )
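
    # Hypothetical usage of the box-prompt sketch above (box given in
    # original image coordinates, XYXY):
    # mask, resized = segment_with_box(image1, box_xyxy=[324, 740, 1448, 1192])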