# Source: 3dilize_anything/SAM2/sam2/sam2_to_dust3r.py
# Uploaded by yansong1616 ("Upload 59 files", commit 8a95b97, verified); 6.29 kB.
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
import json
def build_sam2(cfg, checkpoints):
    """Construct a SAM2 video predictor.

    Thin wrapper around ``build_sam2_video_predictor`` that forwards both
    arguments positionally.

    Args:
        cfg: model config name/path handed to ``build_sam2_video_predictor``.
        checkpoints: checkpoint path handed to ``build_sam2_video_predictor``.

    Returns:
        Whatever ``build_sam2_video_predictor`` returns (the video predictor).
    """
    video_predictor = build_sam2_video_predictor(cfg, checkpoints)
    return video_predictor
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a binary mask on a Matplotlib axes as a semi-transparent colour layer.

    Args:
        mask: array whose last two axes are (H, W) and which holds exactly
            H*W elements (e.g. shape (1, H, W) or (H, W)); it is reshaped to
            (H, W, 1) and multiplied by an RGBA colour.
        ax: axes to draw on; only ``ax.imshow`` is called.
        obj_id: index into the "tab10" colormap; ``None`` means index 0.
        random_color: if True, use a random RGB colour instead of the colormap.
    """
    alpha = 0.6
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = palette(obj_id if obj_id is not None else 0)[:3]
        rgba = np.array([*rgb, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on *ax* as star markers.

    Points with label 1 are drawn in green and points with label 0 in red.
    The positive set is drawn first, then the negative set.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: length-N array of point labels, used as a boolean index.
        ax: axes to draw on; only ``ax.scatter`` is called.
        marker_size: marker area forwarded as ``s``.
    """
    marker_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **marker_style)
# Add point prompts to a frame.
# ann_frame_idx: the frame index we interact with
# ann_obj_id: a unique id for each object we interact with (any integer works)
def add_new_points(predictor, inference_state, ann_frame_idx, ann_obj_id, points, labels):
    """Register point prompts for one object on one frame of the video.

    Args:
        predictor: SAM2 video predictor; only its ``add_new_points`` method is used.
        inference_state: state object created by ``predictor.init_state``.
        ann_frame_idx: index of the frame being annotated.
        ann_obj_id: id of the object the points belong to.
        points: point coordinates, passed through unchanged.
        labels: per-point labels, passed through unchanged.

    Returns:
        ``(out_obj_ids, out_mask_logits)``, the second and third items of the
        predictor's 3-tuple result (the first item is discarded).
    """
    prompt = dict(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )
    _discarded, obj_ids, mask_logits = predictor.add_new_points(**prompt)
    return obj_ids, mask_logits
# Collect the segmentation results for every frame.
def all_frames_masks(predictor, inference_state):
    """Propagate the prompts through the video and binarise every mask.

    Args:
        predictor: SAM2 video predictor; only ``propagate_in_video`` is used.
        inference_state: state object holding the prompts added so far.

    Returns:
        dict mapping each frame index yielded by ``propagate_in_video`` to a
        dict ``{obj_id: numpy bool array}``. A pixel is True where that
        object's mask logit is strictly greater than 0.
    """
    return {
        frame_idx: {
            obj_id: (mask_logits[slot] > 0.0).cpu().numpy()
            for slot, obj_id in enumerate(obj_ids)
        }
        for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state)
    }
def resize_mask_to_img(masks, target_width, target_height):
    """Merge each frame's object masks and resize the result to a target size.

    Args:
        masks: dict ``{frame_idx: {obj_id: bool ndarray}}``, i.e. the output of
            ``all_frames_masks``. The last two axes of every mask are (H, W).
            A leading singleton axis, e.g. shape (1, H, W), is allowed.
        target_width: width of the returned masks, in pixels.
        target_height: height of the returned masks, in pixels.

    Returns:
        A list with one 2-D bool array of shape (target_height, target_width)
        per frame, in ascending frame-index order. Each array is the union
        (logical OR) of that frame's object masks. A frame without any object
        mask yields an all-True mask. An empty ``masks`` dict yields ``[]``.

    Raises:
        ValueError: if no frame contains any object mask, because then the
            source resolution needed for the all-True fallback is unknown.
    """
    if not masks:
        return []

    # Source resolution for the all-True fallback. Take it from the first mask
    # that exists instead of hard-coding frame 0 / object id 1, which raised
    # KeyError whenever that particular entry was missing.
    reference = next(
        (obj_mask for frame_objs in masks.values() for obj_mask in frame_objs.values()),
        None,
    )
    if reference is None:
        raise ValueError("cannot infer mask size: no frame contains an object mask")

    resized = []
    # Sort by frame index so the output follows frame order even if the dict
    # was not filled in order.
    for frame_idx in sorted(masks):
        object_masks = list(masks[frame_idx].values())
        if object_masks:
            union = np.logical_or.reduce(object_masks)
        else:
            union = np.ones(reference.shape, dtype=bool)
        # (1, H, W) -> (H, W); an already 2-D (H, W) mask is left as is.
        plane = union.reshape(union.shape[-2:])
        mask_image = Image.fromarray(plane.astype(np.uint8) * 255)
        # NEAREST keeps the resized mask strictly binary (no interpolated grey levels).
        resized_image = mask_image.resize((target_width, target_height), Image.NEAREST)
        resized.append(np.array(resized_image) > 0)
    return resized
def sava_mask(output_folder, mask, filename="binary_mask.jpg"):
    """Save one binary mask as an 8-bit grayscale image (pixel values 0 / 255).

    Args:
        output_folder: directory to write into. It is created if it does not
            exist yet (previously saving into a missing folder failed).
        mask: mask array with a leading singleton axis, e.g. shape (1, H, W);
            that axis is removed with ``squeeze(0)``.
        filename: name of the output file, ``"binary_mask.jpg"`` by default.
            NOTE: JPEG compression is lossy, so pixels near mask edges may not
            be exactly 0 or 255 after saving. Pass a ``.png`` name if a
            lossless copy is needed.

    Returns:
        The path of the written file.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    # 'L' = 8-bit single-channel (grayscale) mode.
    binary_image = Image.fromarray(mask.squeeze(0).astype(np.uint8) * 255, 'L')
    new_file_path = os.path.join(output_folder, filename)
    binary_image.save(new_file_path)
    print(f"sava mask to {new_file_path} .")
    return new_file_path
# Default SAM2 checkpoint. Written as a raw string: the previous plain literal
# relied on invalid backslash escapes ("\X", "\m", "\D", ...), which Python
# warns about. The resulting path value is identical.
_DEFAULT_SAM2_CHECKPOINT = r"D:\XMU\mac\hujie\3D\DUST3RwithSAM2\dust3rWithSam2\SAM2\checkpoints\sam2_hiera_large.pt"
_DEFAULT_SAM2_MODEL_CFG = "sam2_hiera_l.yaml"
# Object id assigned to the prompted object; its mask is saved as ground truth.
_PROMPT_OBJ_ID = 1
_FRAME_EXTENSIONS = (".jpg", ".jpeg", ".JPG", ".JPEG", ".png")


def _load_scene_prompts(dataset_name, scene_name):
    """Return ``(points, labels)`` for *scene_name* from ``data/<dataset_name>/prompts.json``.

    points are float32 and labels int32 numpy arrays built from the scene
    entry's ``"points"`` and ``"labels"`` keys.
    """
    json_path = os.path.join("data", dataset_name, "prompts.json")
    with open(json_path, "r", encoding="utf-8") as file:
        prompts = json.load(file)[scene_name]
    points = np.array(prompts['points'], dtype=np.float32)
    labels = np.array(prompts['labels'], dtype=np.int32)
    return points, labels


def _show_segments(video_dir, video_segments, points, labels, stride):
    """Plot every *stride*-th frame with its masks overlaid (prompt points on frame 0)."""
    frame_names = [
        name for name in sorted(os.listdir(video_dir))
        if os.path.splitext(name)[-1] in _FRAME_EXTENSIONS
    ]
    plt.close("all")
    for frame_idx in range(0, len(frame_names), stride):
        plt.figure(figsize=(6, 4))
        # imshow converts the PIL image to an array, so the file can be closed right after.
        with Image.open(os.path.join(video_dir, frame_names[frame_idx])) as frame:
            plt.imshow(frame)
        for obj_id, obj_mask in video_segments.get(frame_idx, {}).items():
            show_mask(obj_mask, plt.gca(), obj_id=obj_id)
        if frame_idx == 0:
            # Prompts were added on frame 0.
            show_points(points, labels, plt.gca())
        plt.title(f"Frame {frame_idx}")
        plt.axis('off')
        plt.show()


# Run SAM2 over all frames of a scene and return masks resized for DUST3R.
def get_masks_from_sam2(
    dataset_name,
    scene_name,
    img_shape,
    h,
    w,
    target_ind,
    sam2_checkpoint=_DEFAULT_SAM2_CHECKPOINT,
    model_cfg=_DEFAULT_SAM2_MODEL_CFG,
    visualize=True,
    vis_frame_stride=3,
):
    """Segment a scene's frames with SAM2 from point prompts on frame 0.

    Args:
        dataset_name: dataset folder under ``data/``.
        scene_name: scene folder. Frames are read from
            ``data/<dataset_name>/<scene_name>/images_8`` and prompts from the
            ``scene_name`` entry of ``data/<dataset_name>/prompts.json``.
        img_shape: unused by this function; kept for backward compatibility.
        h: height the returned masks are resized to.
        w: width the returned masks are resized to.
        target_ind: frame index whose SAM2 mask is saved into
            ``data/<dataset_name>/masks/<scene_name>`` as the ground-truth mask
            (used for IoU / Acc evaluation).
        sam2_checkpoint: SAM2 checkpoint path.
        model_cfg: SAM2 model config name.
        visualize: if True, plot every ``vis_frame_stride``-th frame with its
            masks, calling ``plt.show()`` once per plotted frame.
        vis_frame_stride: stride between visualised frames.

    Returns:
        The list produced by ``resize_mask_to_img(video_segments, w, h)``.
    """
    predictor = build_sam2(model_cfg, sam2_checkpoint)

    video_dir = os.path.join("data", dataset_name, scene_name, "images_8")
    inference_state = predictor.init_state(video_path=video_dir)
    predictor.reset_state(inference_state)

    # Prompt the object on frame 0, then propagate through the whole video.
    points, labels = _load_scene_prompts(dataset_name, scene_name)
    add_new_points(predictor, inference_state, 0, _PROMPT_OBJ_ID, points, labels)
    video_segments = all_frames_masks(predictor, inference_state)

    if visualize:
        _show_segments(video_dir, video_segments, points, labels, vis_frame_stride)

    # Save the mask of view target_ind as the ground-truth mask for IoU / Acc.
    mask_dir = os.path.join("data", dataset_name, "masks", scene_name)
    sava_mask(mask_dir, video_segments[target_ind][_PROMPT_OBJ_ID])

    # Resize SAM2's masks to the size DUST3R expects.
    return resize_mask_to_img(video_segments, w, h)
def array_to_tensor_masks(masks_list):
    """Flatten a sequence of equally-shaped masks into one boolean tensor.

    Args:
        masks_list: non-empty sequence of arrays that all have the same shape
            (e.g. 2-D (H, W) masks). ``np.stack`` raises ValueError if it is
            empty or the shapes differ.

    Returns:
        A ``torch.bool`` tensor of shape (n, H*W). This is 2-D, with no
        trailing singleton axis. Row i is ``masks_list[i]`` flattened in
        row-major order, and non-zero values become True.
    """
    stacked = np.stack(masks_list)
    num_masks = stacked.shape[0]
    flattened = stacked.reshape(num_masks, -1)
    return torch.tensor(flattened, dtype=torch.bool)