import numpy as np
import cv2
from ultralytics import FastSAM
import torch
import gc

# Available FastSAM checkpoints, keyed by the size name accepted by get_model().
# Insertion order ("small", then "large") is preserved, so error messages list keys in this order.
MODELS = dict(
    small="./models/FastSAM-s.pt",
    large="./models/FastSAM-x.pt",
)

def clear_gpu_memory():
    """
    Release unreferenced Python objects and, when CUDA is present, PyTorch's cached GPU memory.

    Always runs a full garbage-collection pass. On machines with CUDA it also
    empties PyTorch's CUDA caching allocator and collects CUDA IPC memory.
    Returns None.
    """
    gc.collect()
    # Nothing GPU-related to release on CPU-only machines
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def get_model(model_size: str = "large"):
    """
    Construct the FastSAM model for the given size key.

    Args:
        model_size: Key into MODELS ("small" or "large").

    Returns:
        A new FastSAM instance loaded from the checkpoint path MODELS[model_size].
        A fresh instance is created on every call; nothing is cached.

    Raises:
        ValueError: If model_size is not a key of MODELS.
        RuntimeError: If constructing the FastSAM model fails. The underlying
            exception is attached as __cause__.
    """
    if model_size not in MODELS:
        raise ValueError(f"Invalid model size. Available sizes: {list(MODELS.keys())}")

    try:
        return FastSAM(MODELS[model_size])
    except Exception as e:
        # Chain explicitly so the loader's original traceback is reported as the cause
        raise RuntimeError(f"Failed to load model: {str(e)}") from e

def mask_to_points(mask: np.ndarray) -> list:
    """
    Convert a mask to the flattened outline of its largest external contour.

    Args:
        mask: 2D numpy array. Pixels that are non-zero after conversion to
            uint8 are treated as foreground.

    Returns:
        list: Flattened float coordinates [x1, y1, x2, y2, ...] of the external
        contour with the largest area, or [] when the mask is empty or has no
        foreground.
    """
    if mask.size == 0:
        return []

    # Scale 0/1 values to 0/255 so the input is an 8-bit single-channel image
    mask_uint8 = mask.astype(np.uint8) * 255
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); index -2 is the contour list in both.
    contours = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if not contours:
        return []

    # Keep only the largest region
    contour = max(contours, key=cv2.contourArea)
    # Each contour point is stored as [[x, y]]; flattening in C order yields
    # x1, y1, x2, y2, ... and tolist() converts to Python floats in one pass.
    return contour.reshape(-1).astype(float).tolist()

def _build_model_args(conf, iou, bboxes, points, labels, texts, device) -> dict:
    """Assemble keyword arguments for the FastSAM call, adding only the prompts that were supplied."""
    model_args = {
        "device": device,
        "retina_masks": True,
        "conf": conf,
        "iou": iou
    }
    if bboxes is not None:
        model_args["bboxes"] = bboxes
    # Point prompts are forwarded only as a pair; points without labels
    # (or labels without points) are not passed to the model.
    if points is not None and labels is not None:
        model_args["points"] = points
        model_args["labels"] = labels
    if texts is not None:
        model_args["texts"] = texts
    return model_args


def _extract_segments(results) -> list:
    """Convert the masks of the first model result into flattened contour point lists, skipping masks with no contour."""
    if not results:
        return []
    masks = getattr(results[0], "masks", None)
    if masks is None:
        return []
    segments = []
    for mask in masks.data.cpu().numpy():
        contour_points = mask_to_points(mask)
        if contour_points:
            segments.append(contour_points)
    return segments


def segment_image_with_prompt(
    image: np.ndarray,
    model_size: str = "large",
    conf: float = 0.4,
    iou: float = 0.9,
    bboxes: list = None,
    points: list = None,
    labels: list = None,
    texts: str = None,
    device: str = "cpu"
):
    """
    Segment an image with FastSAM, optionally guided by box, point or text prompts.

    Args:
        image: Input image (numpy.ndarray); must not be None.
        model_size: Model size ("small" or "large").
        conf: Confidence threshold passed to the model.
        iou: IoU threshold passed to the model.
        bboxes: Bounding-box prompt [x1, y1, x2, y2, ...].
        points: Point prompts [[x1, y1], [x2, y2], ...]; only used when labels is also given.
        labels: Label for each point [0, 1, ...]; only used when points is also given.
        texts: Text prompt.
        device: Device string forwarded to the model call. Defaults to "cpu",
            the value that was previously hard-coded.

    Returns:
        dict: {"total_segments": int, "segments": list of [x1, y1, x2, y2, ...]}.
        If inference yields no results or no masks, zero segments are returned.

    Raises:
        RuntimeError: Wraps any failure (missing image, invalid model size,
            model loading or inference error). The original exception is
            attached as __cause__.
    """
    try:
        if image is None:
            raise ValueError("Invalid image input")

        model = get_model(model_size)
        model_args = _build_model_args(conf, iou, bboxes, points, labels, texts, device)
        results = model(image, **model_args)
        segments = _extract_segments(results)
    except Exception as e:
        # Drop this frame's references before cleanup; otherwise the model and
        # raw outputs are certain to survive clear_gpu_memory().
        model = None
        results = None
        clear_gpu_memory()
        raise RuntimeError(f"Error processing image: {str(e)}") from e

    # The model and raw results are released when this frame returns; only
    # the extracted point lists are kept.
    return {
        "total_segments": len(segments),
        "segments": segments
    }