# Ultralytics YOLO 🚀, AGPL-3.0 license

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device

from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
                  generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
from .build import build_sam


class Predictor(BasePredictor):
    """Predictor for the Segment Anything Model (SAM), supporting both prompted and whole-image segmentation."""

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initializes the Predictor, forcing segment task, predict mode and a 1024 input size."""
        if overrides is None:
            overrides = {}
        overrides.update(dict(task='segment', mode='predict', imgsz=1024))
        super().__init__(cfg, overrides, _callbacks)
        # SAM requires retina_masks=True so masks are returned at the original image resolution.
        self.args.retina_masks = True
        # Args for set_image
        self.im = None
        self.features = None
        # Args for set_prompts
        self.prompts = {}
        # Args for segment everything
        self.segment_all = False
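
    # A minimal usage sketch (hedged): the checkpoint name 'sam_b.pt', the image
    # path and the box coordinates below are illustrative assumptions, not part
    # of this module. Prompt keyword arguments are forwarded to `inference` by
    # the base predictor:
    #   predictor = Predictor(overrides=dict(model='sam_b.pt'))
    #   results = predictor(source='image.jpg', bboxes=[100, 100, 400, 400])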

    def preprocess(self, im):
        """Prepares input image before inference.

        Args:
            im (torch.Tensor | List[np.ndarray]): BCHW tensor, or list of HWC images [(HWC) x B].

        Returns:
            (torch.Tensor): The preprocessed image, (B, 3, H, W), normalized when the input was not a tensor.
        """
        if self.im is not None:
            return self.im
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        img = im.to(self.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        if not_tensor:
            img = (img - self.mean) / self.std
        return img
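
    # Sketch of the preprocessing contract (hedged; the cv2 usage and shapes
    # are illustrative):
    #   im = [cv2.imread('image.jpg')]  # one HWC BGR uint8 image
    #   x = predictor.preprocess(im)    # -> (1, 3, 1024, 1024) float tensor,
    #                                   #    RGB, mean/std normalized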

    def pre_transform(self, im):
        """Pre-transform input image before inference.

        Args:
            im (List[np.ndarray]): List of images in HWC format, [(h, w, 3) x N].

        Returns:
            (List[np.ndarray]): A list of transformed images.
        """
        assert len(im) == 1, 'SAM model does not currently support batch inference!'
        return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]
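
    # LetterBox with auto=False and center=False resizes the long side to
    # `imgsz` and pads bottom/right to a square. A hedged sketch:
    #   lb = LetterBox(1024, auto=False, center=False)
    #   out = lb(image=np.zeros((480, 640, 3), dtype=np.uint8))  # -> (1024, 1024, 3)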

    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
        """
        Predict masks for the given input prompts, using the currently set image. If no prompts are
        given, delegates to `generate` and segments the whole image instead.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
            points (np.ndarray | List, None): (N, 2), each point in (X, Y) pixel coordinates.
            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
                1 indicates a foreground point and 0 indicates a background point.
            masks (np.ndarray, None): A low-resolution mask input to the model, typically
                coming from a previous prediction iteration. Has shape (N, H, W), where
                for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. If only a single
                mask is needed, the model's predicted quality score can be used
                to select the best mask. For non-ambiguous prompts, such as multiple
                input prompts, multimask_output=False can give better results.

        Returns:
            (torch.Tensor): The output mask logits, (C, H, W), where C is the number of masks.
            (torch.Tensor): A tensor of length C with the model's predicted quality for each mask.
        """
        # Prompts stored via set_prompts take priority over the arguments
        bboxes = self.prompts.pop('bboxes', bboxes)
        points = self.prompts.pop('points', points)
        masks = self.prompts.pop('masks', masks)

        if all(i is None for i in [bboxes, points, masks]):
            return self.generate(im, *args, **kwargs)

        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
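
    # Hedged example of point prompting (image path and coordinates are
    # assumptions): with an image cached via `set_image`, calling the predictor
    # without a source reuses the cached features:
    #   predictor.set_image('image.jpg')
    #   results = predictor(points=[[450, 300]], labels=[1])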

    def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
        """
        Internal function to predict masks for the given input prompts, using the currently set image.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
            points (np.ndarray | List, None): (N, 2), each point in (X, Y) pixel coordinates.
            labels (np.ndarray | List, None): (N, ), labels for the point prompts.
                1 indicates a foreground point and 0 indicates a background point.
            masks (np.ndarray, None): A low-resolution mask input to the model, typically
                coming from a previous prediction iteration. Has shape (N, H, W), where
                for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks.
                For ambiguous input prompts (such as a single click), this will often
                produce better masks than a single prediction. If only a single
                mask is needed, the model's predicted quality score can be used
                to select the best mask. For non-ambiguous prompts, such as multiple
                input prompts, multimask_output=False can give better results.

        Returns:
            (torch.Tensor): The output low-resolution mask logits, (C, H, W), where C is the
                number of masks and H=W=256 for SAM.
            (torch.Tensor): A tensor of length C with the model's predicted quality for each mask.
        """
        features = self.model.image_encoder(im) if self.features is None else self.features

        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
        # Transform input prompts from original-image pixels to model-input space
        if points is not None:
            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
            points = points[None] if points.ndim == 1 else points
            # Assume labels are all positive if the user doesn't pass labels.
            if labels is None:
                labels = np.ones(points.shape[0])
            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
            points *= r
            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
            points, labels = points[:, None, :], labels[:, None]
        if bboxes is not None:
            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
            bboxes *= r
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
            masks = masks[:, None, :, :]

        points = (points, labels) if points is not None else None
        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=bboxes,
            masks=masks,
        )

        # Predict masks
        pred_masks, pred_scores = self.model.mask_decoder(
            image_embeddings=features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
        # `d` can be 1 or 3 depending on `multimask_output`.
        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
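
    # Worked example of the prompt scale factor `r` (numbers are illustrative):
    # a 1280x720 source letterboxed into a 1024x1024 input gives
    # r = min(1024 / 720, 1024 / 1280) = 0.8, so a click at (640, 360) in the
    # original image maps to (512, 288) in model-input coordinates.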

    def generate(self,
                 im,
                 crop_n_layers=0,
                 crop_overlap_ratio=512 / 1500,
                 crop_downscale_factor=1,
                 point_grids=None,
                 points_stride=32,
                 points_batch_size=64,
                 conf_thres=0.88,
                 stability_score_thresh=0.95,
                 stability_score_offset=0.95,
                 crop_nms_thresh=0.7):
        """Segment the whole image.

        Args:
            im (torch.Tensor): The preprocessed image, (N, C, H, W).
            crop_n_layers (int): If >0, mask prediction will be run again on
                crops of the image. Sets the number of layers to run, where each
                layer has 2**i_layer number of image crops.
            crop_overlap_ratio (float): Sets the degree to which crops overlap.
                In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            crop_downscale_factor (int): The number of points-per-side
                sampled in layer n is scaled down by crop_downscale_factor**n.
            point_grids (List[np.ndarray], None): A list of explicit grids
                of points used for sampling, normalized to [0, 1]. The nth grid in the
                list is used in the nth crop layer. Exclusive with points_stride.
            points_stride (int, None): The number of points to be sampled
                along one side of the image. The total number of points is
                points_stride**2. If None, 'point_grids' must provide explicit
                point sampling.
            points_batch_size (int): Sets the number of points run simultaneously
                by the model. Higher numbers may be faster but use more GPU memory.
            conf_thres (float): A filtering threshold in [0, 1], using the
                model's predicted mask quality.
            stability_score_thresh (float): A filtering threshold in [0, 1], using
                the stability of the mask under changes to the cutoff used to binarize
                the model's mask predictions.
            stability_score_offset (float): The amount to shift the cutoff when
                calculating the stability score.
            crop_nms_thresh (float): The box IoU cutoff used by non-maximal
                suppression to filter duplicate masks between different crops.

        Returns:
            (torch.Tensor): The output masks, (C, H, W) bool, at the model-input resolution.
            (torch.Tensor): Quality scores for each mask, length C.
            (torch.Tensor): Boxes for each mask, (C, 4) in XYXY format.
        """
        self.segment_all = True
        ih, iw = im.shape[2:]
        crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
        if point_grids is None:
            point_grids = build_all_layer_point_grids(
                points_stride,
                crop_n_layers,
                crop_downscale_factor,
            )
        pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
        for crop_region, layer_idx in zip(crop_regions, layer_idxs):
            x1, y1, x2, y2 = crop_region
            w, h = x2 - x1, y2 - y1
            area = torch.tensor(w * h, device=im.device)
            points_scale = np.array([[w, h]])  # w, h
            # Crop image and interpolate to input size
            crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
            # (num_points, 2)
            points_for_image = point_grids[layer_idx] * points_scale
            crop_masks, crop_scores, crop_bboxes = [], [], []
            for (points, ) in batch_iterator(points_batch_size, points_for_image):
                pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
                # Interpolate predicted masks to crop size
                pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
                idx = pred_score > conf_thres
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]

                stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
                                                            stability_score_offset)
                idx = stability_score > stability_score_thresh
                pred_mask, pred_score = pred_mask[idx], pred_score[idx]
                # Bool type is much more memory-efficient.
                pred_mask = pred_mask > self.model.mask_threshold
                # (N, 4)
                pred_bbox = batched_mask_to_box(pred_mask).float()
                keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
                if not torch.all(keep_mask):
                    pred_bbox = pred_bbox[keep_mask]
                    pred_mask = pred_mask[keep_mask]
                    pred_score = pred_score[keep_mask]

                crop_masks.append(pred_mask)
                crop_bboxes.append(pred_bbox)
                crop_scores.append(pred_score)

            # Do NMS within this crop
            crop_masks = torch.cat(crop_masks)
            crop_bboxes = torch.cat(crop_bboxes)
            crop_scores = torch.cat(crop_scores)
            keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou)  # NMS
            crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
            crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
            crop_scores = crop_scores[keep]

            pred_masks.append(crop_masks)
            pred_bboxes.append(crop_bboxes)
            pred_scores.append(crop_scores)
            region_areas.append(area.expand(len(crop_masks)))

        pred_masks = torch.cat(pred_masks)
        pred_bboxes = torch.cat(pred_bboxes)
        pred_scores = torch.cat(pred_scores)
        region_areas = torch.cat(region_areas)

        # Remove duplicate masks between crops, preferring masks from smaller crops
        if len(crop_regions) > 1:
            scores = 1 / region_areas
            keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
            pred_masks = pred_masks[keep]
            pred_bboxes = pred_bboxes[keep]
            pred_scores = pred_scores[keep]

        return pred_masks, pred_scores, pred_bboxes
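
    # Hedged sketch of segment-everything mode (source path is an assumption):
    # calling the predictor with no box/point/mask prompts routes `inference`
    # to `generate`:
    #   results = predictor(source='image.jpg')
    # Setting crop_n_layers > 0 would additionally re-run prediction on
    # overlapping image crops at the cost of extra inference passes.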

    def setup_model(self, model, verbose=True):
        """Set up the SAM model, device, and normalization parameters."""
        device = select_device(self.args.device, verbose=verbose)
        if model is None:
            model = build_sam(self.args.model)
        model.eval()
        self.model = model.to(device)
        self.device = device
        # ImageNet mean/std on the 0-255 scale, used for input normalization
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
        # TODO: Temporary settings for compatibility
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        self.done_warmup = True
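
    # Hedged sketch: a SAM model can also be built and attached explicitly
    # ('sam_b.pt' is an assumed checkpoint name):
    #   model = build_sam('sam_b.pt')
    #   predictor.setup_model(model, verbose=False)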

    def postprocess(self, preds, img, orig_imgs):
        """Postprocesses inference output predictions to create detection masks for objects."""
        # (N, 1, H, W), (N, 1)
        pred_masks, pred_scores = preds[:2]
        pred_bboxes = preds[2] if self.segment_all else None
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
        results = []
        for i, masks in enumerate([pred_masks]):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            if pred_bboxes is not None:
                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)

            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
            masks = masks > self.model.mask_threshold  # to bool
            path = self.batch[0]
            img_path = path[i] if isinstance(path, list) else path
            results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
        # Reset segment-all mode.
        self.segment_all = False
        return results
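
    # Hedged sketch of consuming the results (attribute access follows the
    # ultralytics Results API; the index and shapes are illustrative):
    #   r = results[0]
    #   r.masks.data  # (C, H, W) bool masks at original image resolution
    #   r.boxes.xyxy  # (C, 4) XYXY boxes when run in segment-everything mode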

    def setup_source(self, source):
        """Sets up source and inference mode."""
        if source is not None:
            super().setup_source(source)

    def set_image(self, image):
        """Set an image in advance so its encoder features can be reused across prompts.

        Args:
            image (str | np.ndarray): Image file path, or an np.ndarray image as loaded by cv2.
        """
        if self.model is None:
            model = build_sam(self.args.model)
            self.setup_model(model)
        self.setup_source(image)
        assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
        for batch in self.dataset:
            im = self.preprocess(batch[1])
            self.features = self.model.image_encoder(im)
            self.im = im
            break
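
    # Hedged workflow sketch: cache one image, run several prompt sets against
    # the same encoder features, then release the cache (paths and prompts are
    # assumptions):
    #   predictor.set_image('image.jpg')
    #   r1 = predictor(bboxes=[100, 100, 400, 400])
    #   r2 = predictor(points=[[250, 250]], labels=[1])
    #   predictor.reset_image()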

    def set_prompts(self, prompts):
        """Set prompts in advance."""
        self.prompts = prompts

    def reset_image(self):
        """Reset the cached image and image features."""
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates. Requires opencv-python as a dependency.

        Args:
            masks (torch.Tensor): Masks, (N, H, W).
            min_area (int): Minimum area threshold.
            nms_thresh (float): NMS threshold.

        Returns:
            (torch.Tensor): The filtered masks, (K, H, W) with K <= N.
        """
        if len(masks) == 0:
            return masks

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for mask in masks:
            mask = mask.cpu().numpy()
            mask, changed = remove_small_regions(mask, min_area, mode='holes')
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode='islands')
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(new_masks)
        keep = torchvision.ops.nms(
            boxes.float(),
            torch.as_tensor(scores),
            nms_thresh,
        )

        # Only recalculate masks for masks that have changed
        for i in keep:
            if scores[i] == 0.0:
                masks[i] = new_masks[i]

        return masks[keep]
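
    # Hedged usage sketch (the mask source and thresholds are illustrative):
    #   masks = results[0].masks.data  # (N, H, W) bool
    #   cleaned = Predictor.remove_small_regions(masks, min_area=100, nms_thresh=0.7)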