# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import cv2 # type: ignore from segment_anything import SamAutomaticMaskGenerator, sam_model_registry from segment_anything.utils.amg import ( batch_iterator, generate_crop_boxes, ) import argparse import json import os from typing import Any, Dict, List import numpy as np import matplotlib.pyplot as plt import time import torch from tqdm import tqdm parser = argparse.ArgumentParser( description=( "Runs automatic mask generation on an input image or directory of images, " "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, " "as well as pycocotools if saving in RLE format." ) ) parser.add_argument( "--input", type=str, required=True, help="Path to either a single input image or folder of images.", ) parser.add_argument( "--output", type=str, required=True, help=( "Path to the directory where masks will be output. Output will be either a folder " "of PNGs per image or a single json with COCO-style masks." ), ) parser.add_argument( "--model-type", type=str, default="default", help="The type of model to load, in ['default', 'vit_l', 'vit_b']", ) parser.add_argument( "--checkpoint", type=str, required=True, help="The path to the SAM checkpoint to use for mask generation.", ) parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.") parser.add_argument( "--convert-to-rle", action="store_true", help=( "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. " "Requires pycocotools." ), ) amg_settings = parser.add_argument_group("AMG Settings") amg_settings.add_argument( "--points-per-side", type=int, default=None, help="Generate masks by sampling a grid over the image with this many points to a side.", ) amg_settings.add_argument( "--points-per-batch", type=int, default=None, help="How many input points to process simultaneously in one batch.", ) amg_settings.add_argument( "--pred-iou-thresh", type=float, default=None, help="Exclude masks with a predicted score from the model that is lower than this threshold.", ) amg_settings.add_argument( "--stability-score-thresh", type=float, default=None, help="Exclude masks with a stability score lower than this threshold.", ) amg_settings.add_argument( "--stability-score-offset", type=float, default=None, help="Larger values perturb the mask more when measuring stability score.", ) amg_settings.add_argument( "--box-nms-thresh", type=float, default=None, help="The overlap threshold for excluding a duplicate mask.", ) amg_settings.add_argument( "--crop-n-layers", type=int, default=None, help=( "If >0, mask generation is run on smaller crops of the image to generate more masks. " "The value sets how many different scales to crop at." ), ) amg_settings.add_argument( "--crop-nms-thresh", type=float, default=None, help="The overlap threshold for excluding duplicate masks across different crops.", ) amg_settings.add_argument( "--crop-overlap-ratio", type=int, default=None, help="Larger numbers mean image crops will overlap more.", ) amg_settings.add_argument( "--crop-n-points-downscale-factor", type=int, default=None, help="The number of points-per-side in each layer of crop is reduced by this factor.", ) amg_settings.add_argument( "--min-mask-region-area", type=int, default=None, help=( "Disconnected mask regions or holes with area smaller than this value " "in pixels are removed by postprocessing." ), ) # add hourglass settings amg_settings.add_argument( "--use_hourglass", action="store_true", help="Use hourglass method to expedite mask generation.", ) amg_settings.add_argument( "--hourglass_clustering_location", type=int, default=6, help="location of clustering, ranging from [0, num of layers of transformer)" ) amg_settings.add_argument( "--hourglass_num_cluster", type=int, default=100, help="num of clusters, no more than total number of features" ) amg_settings.add_argument( "--hourglass_cluster_iters", type=int, default=5, help="num of iterations in clustering" ) amg_settings.add_argument( "--hourglass_temperture", type=float, default=5e-3, help="temperture in clustering and reconstruction" ) amg_settings.add_argument( "--hourglass_cluster_window_size", type=int, default=5, help="window size in clustering" ) amg_settings.add_argument( "--hourglass_reconstruction_k", type=int, default=20, help="k in token reconstruction layer of hourglass vit" ) def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None: header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa metadata = [header] for i, mask_data in enumerate(masks): mask = mask_data["segmentation"] filename = f"{i}.png" cv2.imwrite(os.path.join(path, filename), mask * 255) mask_metadata = [ str(i), str(mask_data["area"]), *[str(x) for x in mask_data["bbox"]], *[str(x) for x in mask_data["point_coords"][0]], str(mask_data["predicted_iou"]), str(mask_data["stability_score"]), *[str(x) for x in mask_data["crop_box"]], ] row = ",".join(mask_metadata) metadata.append(row) metadata_path = os.path.join(path, "metadata.csv") with open(metadata_path, "w") as f: f.write("\n".join(metadata)) return def get_amg_kwargs(args): amg_kwargs = { "points_per_side": args.points_per_side, "points_per_batch": args.points_per_batch, "pred_iou_thresh": args.pred_iou_thresh, "stability_score_thresh": args.stability_score_thresh, "stability_score_offset": args.stability_score_offset, "box_nms_thresh": args.box_nms_thresh, "crop_n_layers": args.crop_n_layers, "crop_nms_thresh": args.crop_nms_thresh, "crop_overlap_ratio": args.crop_overlap_ratio, "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor, "min_mask_region_area": args.min_mask_region_area, } amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None} return amg_kwargs def get_hourglass_kwargs(args): hourglass_kwargs = { "use_hourglass": args.use_hourglass, "hourglass_clustering_location": args.hourglass_clustering_location, "hourglass_num_cluster": args.hourglass_num_cluster, "hourglass_cluster_iters": args.hourglass_cluster_iters, "hourglass_temperture": args.hourglass_temperture, "hourglass_cluster_window_size": args.hourglass_cluster_window_size, "hourglass_reconstruction_k": args.hourglass_reconstruction_k, } hourglass_kwargs = {k: v for k, v in hourglass_kwargs.items() if v is not None} return hourglass_kwargs def show_anns(anns): if len(anns) == 0: return sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) ax = plt.gca() ax.set_autoscale_on(False) for ann in sorted_anns: m = ann['segmentation'] img = np.ones((m.shape[0], m.shape[1], 3)) color_mask = np.random.random((1, 3)).tolist()[0] for i in range(3): img[:,:,i] = color_mask[i] ax.imshow(np.dstack((img, m*0.35))) def main(args: argparse.Namespace) -> None: print("Loading model...") hourglass_kwargs = get_hourglass_kwargs(args) sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint, **hourglass_kwargs) _ = sam.to(device=args.device) output_mode = "coco_rle" if args.convert_to_rle else "binary_mask" amg_kwargs = get_amg_kwargs(args) generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs) total_time = 0 warmup = 50 num_samples = 200 for i in tqdm(range(num_samples)): image = np.random.randint(0, 255, size=(1024, 1024, 3), dtype=np.uint8) start = time.perf_counter() # masks = generator.generate(image) with torch.no_grad(): # mask_data = generator._generate_masks(image) orig_size = image.shape[:2] crop_boxes, layer_idxs = generate_crop_boxes( orig_size, generator.crop_n_layers, generator.crop_overlap_ratio ) # Iterate over image crops for crop_box, crop_layer_idx in zip(crop_boxes, layer_idxs): # crop_data = generator._process_crop(image, crop_box, layer_idx, orig_size) x0, y0, x1, y1 = crop_box cropped_im = image[y0:y1, x0:x1, :] cropped_im_size = cropped_im.shape[:2] generator.predictor.set_image(cropped_im) points_scale = np.array(cropped_im_size)[None, ::-1] points_for_image = generator.point_grids[crop_layer_idx] * points_scale for (points,) in batch_iterator(generator.points_per_batch, points_for_image): # batch_data = generator._process_batch(points, cropped_im_size, crop_box, orig_size) transformed_points = generator.predictor.transform.apply_coords(points, cropped_im_size) in_points = torch.as_tensor(transformed_points, device=generator.predictor.device) in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) masks, iou_preds, _ = generator.predictor.predict_torch( in_points[:, None, :], in_labels[:, None], multimask_output=True, return_logits=True, ) del masks del iou_preds eta = time.perf_counter() - start if i >= warmup: total_time += eta print("Done!") print(f"Average time per image: {total_time / (num_samples - warmup)} seconds") if __name__ == "__main__": args = parser.parse_args() main(args)