import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
import json

def build_sam2(cfg, checkpoints):
    return build_sam2_video_predictor(cfg, checkpoints)


def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)



def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# Add point prompts to a frame.
# ann_frame_idx: the frame index we interact with
# ann_obj_id: a unique id for each object we interact with (it can be any integer)
def add_new_points(predictor, inference_state, ann_frame_idx, ann_obj_id, points, labels):
    _, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )
    return out_obj_ids, out_mask_logits
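
# A minimal prompt example (assumed values, not from this repo): one positive
# click at pixel (x=210, y=350) on frame 0 for object id 1; label 1 marks
# foreground and label 0 marks background.
# example_points = np.array([[210.0, 350.0]], dtype=np.float32)
# example_labels = np.array([1], dtype=np.int32)
# add_new_points(predictor, inference_state, 0, 1, example_points, example_labels)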

# Get the segmentation results for all frames
def all_frames_masks(predictor, inference_state):
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments
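
# Note: the returned dict maps frame index -> {obj_id: boolean mask of shape
# (1, H, W)}; e.g. video_segments[0][1] is the mask of object 1 on frame 0.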

def resize_mask_to_img(masks, target_width, target_height):
    frame_mask = []
    origin_size = masks[0][1].shape  # the key 1 is the object id
    for frame, objects_mask in masks.items():  # each frame and its segmentation results
        # A frame may contain masks for several objects.
        obj_masks = list(objects_mask.values())
        if not obj_masks:  # no object in the current frame
            frame_mask.append(np.ones(origin_size, dtype=bool))
        else:  # take the union of all object masks in the current frame
            union_mask = obj_masks[0]
            for mask in obj_masks[1:]:
                union_mask = np.logical_or(union_mask, mask)
            frame_mask.append(union_mask)
    resized_mask = []
    for mask in frame_mask:
        mask_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255)
        resized_mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
        resized_mask.append(np.array(resized_mask_image) > 0)

    return resized_mask
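
# For example (assumed shapes): if the masks dict holds 20 frames of (1, 378, 504)
# object masks and DUSt3R expects 512x384 input, resize_mask_to_img(masks, 512, 384)
# returns a list of 20 boolean arrays of shape (384, 512).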

def save_mask(output_folder, mask):
    os.makedirs(output_folder, exist_ok=True)

    # Convert to a PIL Image ('L' means 8-bit grayscale)
    binary_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255, 'L')

    new_file_path = os.path.join(output_folder, "binary_mask.jpg")

    # Save the mask image
    binary_image.save(new_file_path)
    print(f"Saved mask to {new_file_path}.")

# Run SAM2 to get the segmentation results for all frames
def get_masks_from_sam2(dataset_name, scene_name, img_shape, h, w, target_ind):
    # Load the model (a raw string keeps the Windows path from being mangled by escape sequences)
    sam2_checkpoint = r"D:\XMU\mac\hujie\3D\DUST3RwithSAM2\dust3rWithSam2\SAM2\checkpoints\sam2_hiera_large.pt"
    model_cfg = "sam2_hiera_l.yaml"

    predictor = build_sam2(model_cfg, sam2_checkpoint)

    # Directory containing the video frames
    video_dir = os.path.join("data", dataset_name, scene_name, "images_8")

    # Collect the frame image filenames
    frame_names = [
        p for p in sorted(os.listdir(video_dir))
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
    ]

    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)


    # Add point prompts to one frame.
    # Read prompts.json
    json_dir = os.path.join("data", dataset_name, "prompts.json")
    with open(json_dir, 'r') as file:
        data = json.load(file)
    # Parse the prompts for this scene
    prompts = data[scene_name]
    points = np.array(prompts['points'], dtype=np.float32)
    labels = np.array(prompts['labels'], dtype=np.int32)

    out_obj_ids, out_mask_logits = add_new_points(predictor, inference_state, 0, 1, points, labels)

    # Use SAM2 to get the segmentation results for all frames
    video_segments = all_frames_masks(predictor, inference_state)

    # Render and display the results every few frames
    vis_frame_stride = 3
    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
        if out_frame_idx == 0:
            # Show the point prompts on the annotated frame
            show_points(points, labels, plt.gca())

        plt.title(f"Frame {out_frame_idx}")
        plt.axis('off')  # optional: hide the axes
        plt.show()

    # Save the SAM2 output mask of the view at target_ind as the ground-truth mask,
    # used for computing IoU and accuracy.
    mask_dir = os.path.join("data", dataset_name, "masks", scene_name)
    save_mask(mask_dir, video_segments[target_ind][1])
    # Resize the SAM2 masks to the size required by DUSt3R
    resize_mask = resize_mask_to_img(video_segments, w, h)
    return resize_mask


def array_to_tensor_masks(masks_list):
    # Stack the list into a single ndarray of shape (n, H, W)
    masks_array = np.stack(masks_list)

    # Reshape it to (n, H*W)
    masks_array = masks_array.reshape(masks_array.shape[0], -1)

    # Convert to a boolean tensor
    masks_tensor = torch.tensor(masks_array, dtype=torch.bool)
    return masks_tensor
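

# A minimal end-to-end usage sketch (an assumption, not part of the original
# pipeline): "nerf_llff_data" and "fern" are placeholder dataset/scene names,
# (h, w) = (384, 512) is an assumed DUSt3R input size, and img_shape is passed
# as None because get_masks_from_sam2 does not use it as written.
if __name__ == "__main__":
    masks = get_masks_from_sam2("nerf_llff_data", "fern", None, h=384, w=512, target_ind=0)
    masks_tensor = array_to_tensor_masks(masks)
    print(masks_tensor.shape)  # expected: torch.Size([num_frames, h * w])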