---
library_name: transformers
license: mit
---

# RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)

# Model Card for ViT Large (ViT-L) version

<a href="https://colab.research.google.com/drive/1mrOjUNFrfZ2vuTnWrfl9ebAQov3a9S6E?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>

[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/robustsam/robustsam/tree/main)

Official repository for RobustSAM: Segment Anything Robustly on Degraded Images

[Project Page](https://robustsam.github.io/) | [Paper](https://arxiv.org/abs/2406.09627) | [Dataset](https://huggingface.co/robustsam/robustsam/tree/main/dataset)

## Introduction

The Segment Anything Model (SAM) has emerged as a transformative approach to image segmentation, acclaimed for its robust zero-shot segmentation capabilities and flexible prompting system. Nonetheless, its performance degrades on images of low quality. To address this limitation, we propose the Robust Segment Anything Model (RobustSAM), which improves SAM's performance on low-quality images while preserving its promptability and zero-shot generalization.

Our method builds on the pre-trained SAM model and adds only a marginal number of parameters and little extra computation. The additional parameters of RobustSAM can be optimized within 30 hours on eight GPUs, which makes training feasible for typical research laboratories. We also introduce the Robust-Seg dataset, a collection of 688K image-mask pairs with different degradations, designed for training and evaluating our model. Extensive experiments across various segmentation tasks and datasets confirm RobustSAM's superior performance, especially under zero-shot conditions, underscoring its potential for real-world applications. Additionally, our method has been shown to effectively improve the performance of SAM-based downstream tasks such as single-image dehazing and deblurring.

**Disclaimer**: Content from **this** model card has been written by the Hugging Face team, and parts of it were copied from the original [SAM model card](https://github.com/facebookresearch/segment-anything).

# Model Details

The RobustSAM model is made up of the following modules:

- The `VisionEncoder`: a ViT-based image encoder. It computes the image embeddings using attention on patches of the image. Relative positional embeddings are used.
- The `PromptEncoder`: generates embeddings for points and bounding boxes.
- The `MaskDecoder`: a two-way transformer that performs cross-attention from the point embeddings to the image embedding and from the image embedding back to the point embeddings. Its outputs are fed to the `Neck`.
- The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.
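
To make the two-way cross-attention in the `MaskDecoder` more concrete, here is a toy NumPy sketch. It is not the transformers implementation: it assumes a single attention head and no learned projections, and it leaves out the token self-attention, MLP and normalization layers. The helper names are hypothetical, and the sizes are only chosen to resemble SAM's 64x64 image embedding with 256 channels.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, keys_values):
    # single-head scaled dot-product attention without learned projections
    d = queries.shape[-1]
    weights = softmax(queries @ keys_values.T / np.sqrt(d), axis=-1)
    return weights @ keys_values

def two_way_step(prompt_tokens, image_tokens):
    # 1) prompt tokens attend to the image embedding (tokens -> image)
    prompt_tokens = prompt_tokens + cross_attention(prompt_tokens, image_tokens)
    # 2) the image embedding attends back to the updated prompt tokens (image -> tokens)
    image_tokens = image_tokens + cross_attention(image_tokens, prompt_tokens)
    return prompt_tokens, image_tokens

rng = np.random.default_rng(0)
embed_dim = 256                                 # assumed embedding size
prompt = rng.normal(size=(2, embed_dim))        # e.g. one point token plus one output token
image = rng.normal(size=(64 * 64, embed_dim))   # flattened 64x64 image embedding
new_prompt, new_image = two_way_step(prompt, image)
print(new_prompt.shape, new_image.shape)
```

Each direction is plain scaled dot-product attention; only the roles of queries and keys/values are swapped, so both the prompt tokens and the image embedding end up conditioned on each other.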

# Usage

## Prompted-Mask-Generation

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForMaskGeneration

# use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the RobustSAM model and processor
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-large")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-large").to(device)

# load an image from a URL
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# define input points (2D location of an object in the image)
input_points = [[[450, 600]]]  # one (x, y) point for one image
```

```python
# process the image and input points, and move the tensors to the model's device
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)

# generate masks without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# resize the predicted low-resolution masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
# predicted IoU score for each candidate mask
scores = outputs.iou_scores
```

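SAM-style decoders return several candidate masks per prompt, each with a predicted IoU score, and a common next step is to keep only the highest-scoring candidate. The sketch below uses random NumPy arrays in place of real outputs and assumes the layout `masks[0]`: `(point_batch, num_candidates, H, W)` and `scores[0]`: `(point_batch, num_candidates)`; `select_best_masks` is a hypothetical helper. With the real outputs from the snippet above you would pass `masks[0].numpy()` and `scores[0].cpu().numpy()`.

```python
import numpy as np

def select_best_masks(image_masks, image_scores):
    """Keep the highest-scoring candidate mask for each prompt of one image.

    image_masks:  (point_batch, num_candidates, H, W) boolean masks
    image_scores: (point_batch, num_candidates) predicted IoU scores
    """
    best = image_scores.argmax(axis=-1)        # index of the best candidate per prompt
    rows = np.arange(image_masks.shape[0])     # one row per prompt
    return image_masks[rows, best], image_scores[rows, best]

# stand-ins for masks[0] and scores[0] (assumed shapes, made-up values)
rng = np.random.default_rng(0)
fake_masks = rng.random((1, 3, 8, 8)) > 0.5
fake_scores = np.array([[0.71, 0.93, 0.88]])
best_mask, best_score = select_best_masks(fake_masks, fake_scores)
print(best_mask.shape, best_score)
```
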
Besides the image, you can prompt the model with 2D points at the approximate position of the object of interest, a bounding box around the object (given as the x, y coordinates of its top-left and bottom-right corners), or a segmentation mask. At the time of writing, text prompts are not supported by the official model, according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
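
A minimal sketch of assembling point and box prompts for a single image. It assumes the processor accepts an `input_boxes` argument nested like `input_points` in the Prompted-Mask-Generation example (`[image][prompt][coordinates]`), as the transformers SAM processor does; `xywh_to_xyxy` and `build_prompt_kwargs` are hypothetical helpers, and the coordinates are made up.

```python
def xywh_to_xyxy(x, y, width, height):
    # (top-left x, top-left y, width, height) -> [x_min, y_min, x_max, y_max]
    return [x, y, x + width, y + height]

def build_prompt_kwargs(points=None, boxes=None):
    # hypothetical helper: nest the prompts of ONE image as [image][prompt][coordinates]
    kwargs = {}
    if points:
        kwargs["input_points"] = [[[float(x), float(y)] for x, y in points]]
    if boxes:
        for x_min, y_min, x_max, y_max in boxes:
            # boxes are described by their top-left and bottom-right corners
            if x_min > x_max or y_min > y_max:
                raise ValueError(f"box {[x_min, y_min, x_max, y_max]} is not [x_min, y_min, x_max, y_max]")
        kwargs["input_boxes"] = [[[float(v) for v in box] for box in boxes]]
    return kwargs

kwargs = build_prompt_kwargs(points=[(450, 600)], boxes=[xywh_to_xyxy(75, 275, 1650, 575)])
print(kwargs)
# assumed usage: inputs = processor(raw_image, **kwargs, return_tensors="pt")
```

Converting boxes up front catches the most common mistake, passing width and height where the model expects the bottom-right corner.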

For more details, refer to the Colab notebook linked at the top of this card, which walks through how to use the model with a visual example.

## Automatic-Mask-Generation

The model can generate segmentation masks in a "zero-shot" fashion from an input image alone: it is automatically prompted with a grid of `1024` points, which are all fed to the model.

The `mask-generation` pipeline is built for automatic mask generation. The following snippet shows how easily you can run it on any device: simply set `points_per_batch` to a value that fits in your device's memory.

```python
import torch
from transformers import pipeline

# use the first GPU if available, otherwise the CPU (-1)
device = 0 if torch.cuda.is_available() else -1

# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-large", device=device)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
# lower points_per_batch if you run out of memory
outputs = generator(image_url, points_per_batch=256)
```

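To make the `1024`-point grid and the `points_per_batch` argument concrete, here is a NumPy sketch assuming a uniform, cell-centred grid with 32 points per side, in the style of the original SAM automatic mask generator. The helper names and the image size are assumptions, and the pipeline's internal implementation may differ (for example, it can also process image crops).

```python
import numpy as np

def build_point_grid(n_per_side):
    # cell-centred grid in normalised [0, 1] coordinates, shape (n_per_side**2, 2) as (x, y)
    offset = 1.0 / (2 * n_per_side)
    one_side = np.linspace(offset, 1.0 - offset, n_per_side)
    xs = np.tile(one_side[None, :], (n_per_side, 1))
    ys = np.tile(one_side[:, None], (1, n_per_side))
    return np.stack([xs, ys], axis=-1).reshape(-1, 2)

def batch_points(points, points_per_batch):
    # split the prompts into chunks; each chunk is one forward pass through the mask decoder
    return [points[i:i + points_per_batch] for i in range(0, len(points), points_per_batch)]

grid = build_point_grid(32)                      # 32 x 32 = 1024 point prompts
width, height = 1800, 1200                       # assumed image size in pixels
pixel_points = grid * np.array([width, height])  # scale to pixel coordinates
batches = batch_points(pixel_points, points_per_batch=256)
print(grid.shape, len(batches))
```

A smaller `points_per_batch` means more, smaller forward passes: lower peak memory at the cost of speed, while the set of prompted points stays the same.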
Now, to display the generated masks on the image:

```python
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

# load the same image that was passed to the pipeline
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

# simple function to overlay a mask on a matplotlib axis
def show_mask(mask, ax, random_color=False):
    if random_color:
        # random RGB color with a fixed alpha of 0.6
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # get the height and width from the mask
    h, w = mask.shape[-2:]
    # turn the binary mask into an RGBA image
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# display the original image
plt.imshow(np.array(raw_image))
ax = plt.gca()

# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)

plt.axis("off")

# show the image with the masks
plt.show()
```

## Visual Comparison

<table>
  <tr>
    <td>
      <img src="figures/gif_output/blur_back_n_forth.gif" width="380">
    </td>
    <td>
      <img src="figures/gif_output/haze_back_n_forth.gif" width="380">
    </td>
  </tr>
  <tr>
    <td>
      <img src="figures/gif_output/lowlight_back_n_forth.gif" width="380">
    </td>
    <td>
      <img src="figures/gif_output/rain_back_n_forth.gif" width="380">
    </td>
  </tr>
</table>

<img width="1096" alt="image" src='figures/qualitative_result.PNG'>

## Reference

If you find this work useful, please consider citing us!

```bibtex
@inproceedings{chen2024robustsam,
  title={RobustSAM: Segment Anything Robustly on Degraded Images},
  author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
  booktitle={CVPR},
  year={2024}
}
```

## Acknowledgements

We thank the authors of [SAM](https://github.com/facebookresearch/segment-anything), on which our repo is based.