import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
import json
def build_sam2(cfg, checkpoints):
    return build_sam2_video_predictor(cfg, checkpoints)
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
# Add point prompts to a frame.
# ann_frame_idx: the frame index we interact with
# ann_obj_id: a unique id for each object we interact with (it can be any integer)
def add_new_points(predictor, inference_state, ann_frame_idx, ann_obj_id, points, labels):
    _, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )
    return out_obj_ids, out_mask_logits
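# Expected prompt format (hypothetical values): `points` are (x, y) pixel
# coordinates in the frame, `labels` mark each point as foreground (1) or
# background (0), matching the dtypes used when loading prompts.json below, e.g.
#   points = np.array([[210.0, 350.0]], dtype=np.float32)
#   labels = np.array([1], dtype=np.int32)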
# Get the segmentation results for all frames.
def all_frames_masks(predictor, inference_state):
    video_segments = {}  # video_segments contains the per-frame segmentation results
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments
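# Shape sketch of the returned structure (hypothetical ids and shapes, assuming
# SAM2 emits one (1, H, W) logit map per object):
#   video_segments = {
#       0: {1: bool ndarray of shape (1, H, W), ...},
#       1: {1: bool ndarray of shape (1, H, W), ...},
#       ...
#   }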
def resize_mask_to_img(masks, target_width, target_height):
    frame_mask = []
    origin_size = masks[0][1].shape  # 1 is the object id
    for frame, objects_mask in masks.items():  # each frame and its segmentation results
        # A frame may contain masks for multiple objects.
        obj_masks = list(objects_mask.values())
        if not obj_masks:  # no object in the current frame
            frame_mask.append(np.ones(origin_size, dtype=bool))
        else:  # take the union of all object masks in the current frame
            union_mask = obj_masks[0]
            for mask in obj_masks[1:]:
                union_mask = np.logical_or(union_mask, mask)
            frame_mask.append(union_mask)
    resized_mask = []
    for mask in frame_mask:
        mask_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255)
        resized_mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
        resized_mask.append(np.array(resized_mask_image) > 0)
    return resized_mask
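# Minimal usage sketch (placeholder resolution, not prescribed by this module):
#   resized = resize_mask_to_img(video_segments, target_width=504, target_height=378)
# Each entry is then a (target_height, target_width) bool array: PIL's resize
# takes (width, height), but the numpy array comes back as (height, width).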
def save_mask(output_folder, mask):
    # Convert to a PIL Image ('L' means 8-bit grayscale mode).
    binary_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255, 'L')
    os.makedirs(output_folder, exist_ok=True)  # make sure the output folder exists
    new_file_path = os.path.join(output_folder, "binary_mask.jpg")
    # Save the binary mask image.
    binary_image.save(new_file_path)
    print(f"saved mask to {new_file_path}.")
# Run SAM2 to get the segmentation results for all frames.
def get_masks_from_sam2(dataset_name, scene_name, img_shape, h, w, target_ind):
    # Load the model (raw string: backslash sequences like "\X" are invalid escapes otherwise).
    sam2_checkpoint = r"D:\XMU\mac\hujie\3D\DUST3RwithSAM2\dust3rWithSam2\SAM2\checkpoints\sam2_hiera_large.pt"
    model_cfg = "sam2_hiera_l.yaml"
    predictor = build_sam2(model_cfg, sam2_checkpoint)
    # Directory containing the video frames.
    video_dir = os.path.join("data", dataset_name, scene_name, "images_8")
    # Read the frame images.
    frame_names = [
        p for p in sorted(os.listdir(video_dir))
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
    ]
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)
    # Add point prompts to one frame.
    # Read prompts.json.
    json_dir = os.path.join("data", dataset_name, "prompts.json")
    with open(json_dir, 'r') as file:
        data = json.load(file)
    # Parse the prompts.
    prompts = data[scene_name]
    points = np.array(prompts['points'], dtype=np.float32)
    labels = np.array(prompts['labels'], dtype=np.int32)
    out_obj_ids, out_mask_logits = add_new_points(predictor, inference_state, 0, 1, points, labels)
    # Get the segmentation results for all frames with SAM2.
    video_segments = all_frames_masks(predictor, inference_state)
    # Visualize the segmentation results.
    vis_frame_stride = 3
    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.title(f"frame {out_frame_idx}")
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
        if out_frame_idx == 0:
            # Show the point prompts on the annotated frame.
            show_points(points, labels, plt.gca())
        plt.axis('off')  # optional: hide the axes
        plt.show()
    # Save the SAM2 output mask of the view at target_ind as the ground-truth
    # mask, used to compute IoU and Acc.
    mask_dir = os.path.join("data", dataset_name, "masks", scene_name)
    save_mask(mask_dir, video_segments[target_ind][1])
    # Resize the SAM2 masks to the size required by DUST3R.
    resize_mask = resize_mask_to_img(video_segments, w, h)
    return resize_mask
def array_to_tensor_masks(masks_list):
    # Stack the list into one ndarray of shape (n, H, W).
    masks_array = np.stack(masks_list)
    # Reshape it to (n, H*W).
    masks_array = masks_array.reshape(masks_array.shape[0], -1)
    # Convert to a bool Tensor.
    masks_tensor = torch.tensor(masks_array, dtype=torch.bool)
    return masks_tensor
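# Hedged end-to-end sketch of how these helpers fit together. The dataset and
# scene names, resolutions, and target index below are hypothetical
# placeholders, not values prescribed by this module.
if __name__ == "__main__":
    resized_masks = get_masks_from_sam2(
        dataset_name="llff",    # hypothetical: expects data/llff/<scene>/images_8
        scene_name="fern",      # hypothetical: needs an entry in data/llff/prompts.json
        img_shape=(378, 504),   # placeholder original frame shape
        h=288,                  # placeholder DUST3R target height
        w=384,                  # placeholder DUST3R target width
        target_ind=0,           # frame whose mask is saved as the ground-truth mask
    )
    masks_tensor = array_to_tensor_masks(resized_masks)
    print(masks_tensor.shape)   # expected: (num_frames, h * w)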