Expedit-SAM / scripts /benchmark.py
kxqt's picture
init commit
2c50deb
raw
history blame
10.6 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import cv2 # type: ignore
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from segment_anything.utils.amg import (
batch_iterator,
generate_crop_boxes,
)
import argparse
import json
import os
from typing import Any, Dict, List
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
from tqdm import tqdm
# Command-line interface of the benchmark script.
#
# The options fall into three groups:
#   * general options (model type, checkpoint, device, output mode),
#   * "AMG Settings" forwarded to SamAutomaticMaskGenerator when explicitly set,
#   * hourglass options forwarded to the model builder.
parser = argparse.ArgumentParser(
    description=(
        "Benchmarks the average per-image latency of SAM automatic mask generation "
        "(optionally with the hourglass speed-up) on randomly generated images."
    )
)

# --input / --output are never read by main(): the benchmark times synthetic
# images and writes nothing to disk. They remain accepted so that existing
# invocations keep working, but they are no longer mandatory.
parser.add_argument(
    "--input",
    type=str,
    default=None,
    help=(
        "Unused by this benchmark, which times randomly generated images. "
        "Accepted only so that existing invocations keep working."
    ),
)

parser.add_argument(
    "--output",
    type=str,
    default=None,
    help=(
        "Unused by this benchmark, which does not write any masks. "
        "Accepted only so that existing invocations keep working."
    ),
)

parser.add_argument(
    "--model-type",
    type=str,
    default="default",
    help="The type of model to load, in ['default', 'vit_l', 'vit_b']",
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM checkpoint to use for mask generation.",
)

parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")

parser.add_argument(
    "--convert-to-rle",
    action="store_true",
    help=(
        "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
        "Requires pycocotools."
    ),
)

# Every AMG option defaults to None so that get_amg_kwargs() can tell "not
# given on the command line" apart from an explicit value.
amg_settings = parser.add_argument_group("AMG Settings")

amg_settings.add_argument(
    "--points-per-side",
    type=int,
    default=None,
    help="Generate masks by sampling a grid over the image with this many points to a side.",
)

amg_settings.add_argument(
    "--points-per-batch",
    type=int,
    default=None,
    help="How many input points to process simultaneously in one batch.",
)

amg_settings.add_argument(
    "--pred-iou-thresh",
    type=float,
    default=None,
    help="Exclude masks with a predicted score from the model that is lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-thresh",
    type=float,
    default=None,
    help="Exclude masks with a stability score lower than this threshold.",
)

amg_settings.add_argument(
    "--stability-score-offset",
    type=float,
    default=None,
    help="Larger values perturb the mask more when measuring stability score.",
)

amg_settings.add_argument(
    "--box-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding a duplicate mask.",
)

amg_settings.add_argument(
    "--crop-n-layers",
    type=int,
    default=None,
    help=(
        "If >0, mask generation is run on smaller crops of the image to generate more masks. "
        "The value sets how many different scales to crop at."
    ),
)

amg_settings.add_argument(
    "--crop-nms-thresh",
    type=float,
    default=None,
    help="The overlap threshold for excluding duplicate masks across different crops.",
)

# The overlap ratio is a fraction, so it must be parsed as a float; parsing it
# as an int rejected any fractional value such as 0.34.
amg_settings.add_argument(
    "--crop-overlap-ratio",
    type=float,
    default=None,
    help="Larger numbers mean image crops will overlap more.",
)

amg_settings.add_argument(
    "--crop-n-points-downscale-factor",
    type=int,
    default=None,
    help="The number of points-per-side in each layer of crop is reduced by this factor.",
)

amg_settings.add_argument(
    "--min-mask-region-area",
    type=int,
    default=None,
    help=(
        "Disconnected mask regions or holes with area smaller than this value "
        "in pixels are removed by postprocessing."
    ),
)

# Hourglass settings. The destination names (including the spelling of
# "temperture") are forwarded verbatim as keyword arguments to the model
# builder by get_hourglass_kwargs(), so the flag names must not be changed.
amg_settings.add_argument(
    "--use_hourglass",
    action="store_true",
    help="Use hourglass method to expedite mask generation.",
)

amg_settings.add_argument(
    "--hourglass_clustering_location",
    type=int,
    default=6,
    help="location of clustering, ranging from [0, num of layers of transformer)",
)

amg_settings.add_argument(
    "--hourglass_num_cluster",
    type=int,
    default=100,
    help="num of clusters, no more than total number of features",
)

amg_settings.add_argument(
    "--hourglass_cluster_iters",
    type=int,
    default=5,
    help="num of iterations in clustering",
)

amg_settings.add_argument(
    "--hourglass_temperture",
    type=float,
    default=5e-3,
    help="temperture in clustering and reconstruction",
)

amg_settings.add_argument(
    "--hourglass_cluster_window_size",
    type=int,
    default=5,
    help="window size in clustering",
)

amg_settings.add_argument(
    "--hourglass_reconstruction_k",
    type=int,
    default=20,
    help="k in token reconstruction layer of hourglass vit",
)
def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Dump each mask as ``<index>.png`` plus one ``metadata.csv`` into ``path``.

    Args:
        masks: Mask records. Each record is read for the keys ``segmentation``,
            ``area``, ``bbox``, ``point_coords``, ``predicted_iou``,
            ``stability_score`` and ``crop_box``.
        path: Existing directory that receives the PNG files and the CSV.

    The CSV starts with a header line, followed by one comma-separated line
    per mask; no trailing newline is written.
    """
    columns = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    lines = [columns]

    for index, record in enumerate(masks):
        segmentation = record["segmentation"]
        png_path = os.path.join(path, f"{index}.png")
        # Scale by 255 before writing (presumably a 0/1 mask -> black/white image).
        cv2.imwrite(png_path, segmentation * 255)

        fields = [str(index), str(record["area"])]
        fields.extend(str(value) for value in record["bbox"])
        # Only the first input point of the record is stored.
        fields.extend(str(value) for value in record["point_coords"][0])
        fields.append(str(record["predicted_iou"]))
        fields.append(str(record["stability_score"]))
        fields.extend(str(value) for value in record["crop_box"])
        lines.append(",".join(fields))

    with open(os.path.join(path, "metadata.csv"), "w") as handle:
        handle.write("\n".join(lines))
def get_amg_kwargs(args):
    """Return the AMG options from ``args`` whose value is not ``None``.

    Args:
        args: Parsed command-line namespace carrying one attribute per AMG
            option.

    Returns:
        A dict mapping option name to value, containing only the options with
        a non-``None`` value, in the order listed below.
    """
    option_names = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    # Read every option first, then filter, so a missing attribute fails
    # before anything is returned.
    collected = [(name, getattr(args, name)) for name in option_names]
    selected = {}
    for name, value in collected:
        if value is not None:
            selected[name] = value
    return selected
def get_hourglass_kwargs(args):
    """Return the hourglass options from ``args`` whose value is not ``None``.

    Args:
        args: Parsed command-line namespace carrying one attribute per
            hourglass option.

    Returns:
        A dict mapping option name to value. Only ``None`` values are dropped,
        so a ``False`` ``use_hourglass`` is kept.
    """
    option_names = (
        "use_hourglass",
        "hourglass_clustering_location",
        "hourglass_num_cluster",
        "hourglass_cluster_iters",
        "hourglass_temperture",
        "hourglass_cluster_window_size",
        "hourglass_reconstruction_k",
    )
    collected = [(name, getattr(args, name)) for name in option_names]
    return {name: value for name, value in collected if value is not None}
def show_anns(anns):
    """Draw every annotation mask on the current matplotlib axes.

    Args:
        anns: Sequence of annotation records; each is read for the keys
            ``area`` and ``segmentation``.

    Does nothing for an empty sequence. Otherwise the masks are drawn in
    order of decreasing area, each in its own random colour, with the mask
    scaled by 0.35 stacked on as the fourth image channel.
    """
    if len(anns) == 0:
        return

    def _area(annotation):
        return annotation['area']

    largest_first = sorted(anns, key=_area, reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    for annotation in largest_first:
        mask = annotation['segmentation']
        rows, cols = mask.shape[0], mask.shape[1]
        tint = np.ones((rows, cols, 3))
        rgb = np.random.random((1, 3)).tolist()[0]
        # Fill all three colour channels with the random colour at once.
        tint[:, :, :] = rgb
        axes.imshow(np.dstack((tint, mask * 0.35)))
def main(args: argparse.Namespace) -> None:
    """Benchmark the point-prompted mask prediction path of the mask generator.

    Builds the SAM model selected by ``args`` (including the hourglass
    options), wraps it in a ``SamAutomaticMaskGenerator`` and times 200 runs
    on random 1024x1024 uint8 images. The first 50 runs are warm-up and are
    excluded; the average wall-clock time of the remaining runs is printed.

    Each timed run re-implements the crop loop and the point-batch loop of the
    generator inline (image embedding plus prompt-based prediction for every
    point batch) and discards the predictions, so no mask post-processing is
    included in the measurement.

    Args:
        args: Parsed command-line arguments; see the module-level ``parser``.
    """
    print("Loading model...")
    hourglass_kwargs = get_hourglass_kwargs(args)
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint, **hourglass_kwargs)
    _ = sam.to(device=args.device)
    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

    # CUDA kernels are launched asynchronously, so the host clock must only be
    # read after the device has finished its queued work; otherwise the timer
    # can stop while the GPU is still busy and under-report the latency.
    device = torch.device(args.device)
    needs_sync = device.type == "cuda"

    total_time = 0.0
    warmup = 50
    num_samples = 200
    for i in tqdm(range(num_samples)):
        image = np.random.randint(0, 255, size=(1024, 1024, 3), dtype=np.uint8)

        if needs_sync:
            torch.cuda.synchronize(device)
        start = time.perf_counter()

        with torch.no_grad():
            # Inlined counterpart of generator.generate(image) /
            # generator._generate_masks(image), without post-processing.
            orig_size = image.shape[:2]
            crop_boxes, layer_idxs = generate_crop_boxes(
                orig_size, generator.crop_n_layers, generator.crop_overlap_ratio
            )

            # Iterate over image crops (counterpart of generator._process_crop).
            for crop_box, crop_layer_idx in zip(crop_boxes, layer_idxs):
                x0, y0, x1, y1 = crop_box
                cropped_im = image[y0:y1, x0:x1, :]
                cropped_im_size = cropped_im.shape[:2]
                generator.predictor.set_image(cropped_im)

                # Scale the point grid of this layer by the crop's (width, height).
                points_scale = np.array(cropped_im_size)[None, ::-1]
                points_for_image = generator.point_grids[crop_layer_idx] * points_scale

                # Counterpart of generator._process_batch for every point batch.
                for (points,) in batch_iterator(generator.points_per_batch, points_for_image):
                    transformed_points = generator.predictor.transform.apply_coords(
                        points, cropped_im_size
                    )
                    in_points = torch.as_tensor(
                        transformed_points, device=generator.predictor.device
                    )
                    in_labels = torch.ones(
                        in_points.shape[0], dtype=torch.int, device=in_points.device
                    )
                    masks, iou_preds, _ = generator.predictor.predict_torch(
                        in_points[:, None, :],
                        in_labels[:, None],
                        multimask_output=True,
                        return_logits=True,
                    )
                    # Predictions are not needed; drop the references right away.
                    del masks
                    del iou_preds

        if needs_sync:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start

        # Warm-up iterations are run but not counted.
        if i >= warmup:
            total_time += elapsed

    print("Done!")
    print(f"Average time per image: {total_time / (num_samples - warmup)} seconds")
# Script entry point: parse the command line and run the benchmark.
if __name__ == "__main__":
    main(parser.parse_args())