Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from sam2.build_sam import build_sam2_video_predictor | |
import json | |
def build_sam2(cfg, checkpoints):
    """Construct a SAM2 video predictor.

    Thin wrapper around ``build_sam2_video_predictor``; both arguments are
    forwarded unchanged.

    Args:
        cfg: Model config name (this file passes ``"sam2_hiera_l.yaml"``).
        checkpoints: Path to the model checkpoint.

    Returns:
        Whatever ``build_sam2_video_predictor(cfg, checkpoints)`` returns.
    """
    predictor = build_sam2_video_predictor(cfg, checkpoints)
    return predictor
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a binary mask on a matplotlib axis as a translucent colour layer.

    Args:
        mask: Array whose last two dimensions are (H, W). It must hold exactly
            H*W elements because it is reshaped to (H, W, 1).
        ax: Matplotlib axes to draw onto.
        obj_id: Index into the ``tab10`` colormap; ``None`` selects index 0.
            Ignored when ``random_color`` is true.
        random_color: If true, draw the RGB colour from ``np.random.random``.

    The overlay alpha is fixed at 0.6.
    """
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = palette(0 if obj_id is None else obj_id)[:3]
        rgba = np.array([*rgb, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against a (1, 1, 4) RGBA colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=200):
    """Scatter positive (label 1) and negative (label 0) point prompts.

    Positive points are drawn as green stars and negative points as red
    stars, both with a white edge. Points with any other label are not drawn.

    Args:
        coords: 2-D array; column 0 is used as x and column 1 as y.
        labels: 1-D array of labels aligned with the rows of ``coords``.
        ax: Matplotlib axes to draw onto.
        marker_size: Marker area, passed to ``scatter`` as ``s``.
    """
    shared_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positives first, then negatives, so negatives are drawn on top.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **shared_style)
# Add point prompts to a frame.
# ann_frame_idx: the frame index we interact with
# ann_obj_id: give a unique id to each object we interact with (it can be any integers)
def add_new_points(predictor, inference_state, ann_frame_idx, ann_obj_id, points, labels):
    """Register click prompts for one object on one frame.

    Args:
        predictor: SAM2 video predictor exposing ``add_new_points``.
        inference_state: State object created by ``predictor.init_state``.
        ann_frame_idx: Index of the frame the prompts belong to.
        ann_obj_id: Integer id identifying the prompted object.
        points: Point coordinates forwarded to the predictor.
        labels: Per-point labels forwarded to the predictor.

    Returns:
        Tuple ``(out_obj_ids, out_mask_logits)``: the second and third items
        returned by ``predictor.add_new_points``; the first item is discarded.
    """
    result = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )
    _discarded, obj_ids, mask_logits = result
    return obj_ids, mask_logits
# Collect the segmentation result for every frame.
def all_frames_masks(predictor, inference_state):
    """Run SAM2 propagation and binarise every object's mask on every frame.

    Args:
        predictor: SAM2 video predictor exposing ``propagate_in_video``.
        inference_state: State object the prompts were added to.

    Returns:
        dict mapping each frame index yielded by ``propagate_in_video`` to a
        dict ``{obj_id: numpy bool array}``. The i-th object's array is
        ``out_mask_logits[i] > 0.0``, moved to CPU.
    """
    def _binarise(obj_ids, logits):
        # Threshold the logits at 0 and convert each object's mask to numpy.
        return {oid: (logits[k] > 0.0).cpu().numpy() for k, oid in enumerate(obj_ids)}

    return {
        frame_idx: _binarise(obj_ids, logits)
        for frame_idx, obj_ids, logits in predictor.propagate_in_video(inference_state)
    }
def resize_mask_to_img(masks, target_width, target_height):
    """Merge the per-object masks of each frame and resize each to a target size.

    Args:
        masks: dict ``{frame_idx: {obj_id: mask}}`` such as the one returned by
            ``all_frames_masks``. Each mask is a boolean array of shape
            (1, H, W) or (H, W); a leading singleton axis is dropped.
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        list of bool arrays of shape (target_height, target_width), one per
        frame, ordered by ascending frame index. The masks of all objects on a
        frame are OR-ed together. A frame with no object masks contributes an
        all-True mask, so the whole frame is kept.
    """
    resized_masks = []
    # Sort so list position follows frame index regardless of dict insertion order.
    for frame_idx in sorted(masks):
        object_masks = list(masks[frame_idx].values())

        if not object_masks:
            # An all-True mask stays all-True under nearest-neighbour resizing,
            # so build it directly at the target size. This avoids the need for
            # a reference shape, which was previously read from masks[0][1] and
            # raised KeyError when that entry was missing.
            resized_masks.append(np.ones((target_height, target_width), dtype=bool))
            continue

        # Union of every object's mask on this frame.
        union_mask = object_masks[0]
        for extra in object_masks[1:]:
            union_mask = np.logical_or(union_mask, extra)

        plane = np.asarray(union_mask)
        if plane.ndim == 3:
            # Drop the leading singleton axis: (1, H, W) -> (H, W).
            plane = plane.squeeze(0)

        mask_image = Image.fromarray(plane.astype(np.uint8) * 255)
        # NEAREST keeps the mask binary (no interpolated grey levels).
        mask_image = mask_image.resize((target_width, target_height), Image.NEAREST)
        resized_masks.append(np.array(mask_image) > 0)

    return resized_masks
def sava_mask(output_folder, mask, filename="binary_mask.jpg"):
    """Write a binary mask to ``output_folder/filename`` as an 8-bit grayscale image.

    True pixels are written as 255 and False pixels as 0.

    Args:
        output_folder: Destination directory. It is created if it does not
            exist yet.
        mask: Boolean (or 0/1) array of shape (1, H, W) or (H, W).
        filename: Output file name. The default ``"binary_mask.jpg"`` matches
            the original behaviour. NOTE: JPEG is lossy, so pixels near mask
            edges may not be exactly 0/255 when read back. Pass a ``.png``
            name if the file must round-trip exactly (e.g. when it is used as
            ground truth for IoU).

    The misspelled function name is kept for existing callers; ``save_mask``
    is provided as an alias.
    """
    plane = np.asarray(mask)
    if plane.ndim == 3:
        # Drop the leading singleton axis: (1, H, W) -> (H, W).
        plane = plane.squeeze(0)
    # Convert to a PIL image; 'L' is 8-bit single-channel grayscale.
    binary_image = Image.fromarray(plane.astype(np.uint8) * 255, 'L')

    # Image.save does not create missing directories.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    new_file_path = os.path.join(output_folder, filename)
    binary_image.save(new_file_path)
    print(f"sava mask to {new_file_path} .")


save_mask = sava_mask
# Default SAM2 checkpoint and config. A raw string is used because the
# original literal depended on invalid escape sequences such as "\X" and "\m",
# which Python warns about. The resulting path value is unchanged.
_DEFAULT_SAM2_CHECKPOINT = r"D:\XMU\mac\hujie\3D\DUST3RwithSAM2\dust3rWithSam2\SAM2\checkpoints\sam2_hiera_large.pt"
_DEFAULT_SAM2_CONFIG = "sam2_hiera_l.yaml"


def _load_point_prompts(dataset_name, scene_name):
    """Read one scene's point prompts from ``data/<dataset_name>/prompts.json``.

    The file is expected to map scene names to ``{"points": [...], "labels": [...]}``.
    Returns ``(points, labels)`` as float32 and int32 numpy arrays.
    """
    json_path = os.path.join("data", dataset_name, "prompts.json")
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(json_path, 'r', encoding="utf-8") as file:
        data = json.load(file)
    prompts = data[scene_name]
    points = np.array(prompts['points'], dtype=np.float32)
    labels = np.array(prompts['labels'], dtype=np.int32)
    return points, labels


def _show_segments(video_dir, frame_names, video_segments, points, labels, stride):
    """Display every ``stride``-th frame with its masks; frame 0 also shows the prompts."""
    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), stride):
        plt.figure(figsize=(6, 4))
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
        if out_frame_idx == 0:
            # The prompts were placed on frame 0.
            show_points(points, labels, plt.gca())
        plt.title(f"Frame {out_frame_idx}")
        plt.axis('off')
        plt.show()


# Get the segmentation result for all frames via SAM2.
def get_masks_from_sam2(dataset_name, scene_name, img_shape, h, w, target_ind,
                        sam2_checkpoint=_DEFAULT_SAM2_CHECKPOINT,
                        model_cfg=_DEFAULT_SAM2_CONFIG,
                        visualize=True, vis_frame_stride=3):
    """Segment every frame of a scene with SAM2 from point prompts on frame 0.

    Steps:
    1. Build the predictor.
    2. Initialise it on ``data/<dataset_name>/<scene_name>/images_8``.
    3. Add the prompts from ``data/<dataset_name>/prompts.json`` to frame 0 as
       object id 1.
    4. Propagate through the video.
    5. Optionally display every ``vis_frame_stride``-th frame.
    6. Save the mask of frame ``target_ind`` via ``sava_mask`` into
       ``data/<dataset_name>/masks/<scene_name>``, as ground truth for
       IoU/accuracy.
    7. Resize all masks to (h, w).

    Args:
        dataset_name: Dataset folder under ``data/``.
        scene_name: Scene folder, also the key into ``prompts.json``.
        img_shape: Unused; kept for call compatibility.
        h: Target mask height (DUST3R input size).
        w: Target mask width (DUST3R input size).
        target_ind: Frame index whose mask is saved as ground truth.
        sam2_checkpoint: SAM2 checkpoint path (defaults to the original
            hard-coded path).
        model_cfg: SAM2 config name.
        visualize: If true, show matplotlib figures (``plt.show`` is called
            once per figure).
        vis_frame_stride: Step between visualized frames.

    Returns:
        list of bool arrays of shape (h, w), one per frame (see
        ``resize_mask_to_img``).
    """
    predictor = build_sam2(model_cfg, sam2_checkpoint)

    video_dir = os.path.join("data", dataset_name, scene_name, "images_8")
    # NOTE(review): SAM2's own frame loader may accept a different extension
    # set and sort order than this list; confirm the two line up if .png
    # frames are used.
    frame_names = [
        p for p in sorted(os.listdir(video_dir))
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
    ]
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)

    points, labels = _load_point_prompts(dataset_name, scene_name)
    # Prompt object id 1 on frame 0; the same id is read back when saving below.
    add_new_points(predictor, inference_state, 0, 1, points, labels)

    video_segments = all_frames_masks(predictor, inference_state)

    if visualize:
        _show_segments(video_dir, frame_names, video_segments, points, labels, vis_frame_stride)

    mask_dir = os.path.join("data", dataset_name, "masks", scene_name)
    sava_mask(mask_dir, video_segments[target_ind][1])

    # Resize the SAM2 masks to the size DUST3R expects.
    return resize_mask_to_img(video_segments, w, h)
def array_to_tensor_masks(masks_list):
    """Flatten a sequence of equally-shaped masks into one boolean tensor.

    Args:
        masks_list: Non-empty sequence of arrays with identical shape,
            typically (H, W), e.g. the list returned by ``resize_mask_to_img``.

    Returns:
        torch bool tensor of shape (n, H*W). Row i is mask i flattened in
        row-major order, and nonzero values become True. The data is copied,
        so the tensor does not share memory with the inputs.

    Raises:
        ValueError: If ``masks_list`` is empty or the masks differ in shape
            (raised by ``np.stack``).
    """
    # Stack into a single ndarray of shape (n, H, W).
    masks_array = np.stack(masks_list)
    # Flatten each mask to shape (n, H*W); there is no trailing singleton axis.
    flat = masks_array.reshape(len(masks_array), -1)
    return torch.tensor(flat, dtype=torch.bool)