---
library_name: transformers
license: mit
---
# RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)
# Model Card for ViT Large (ViT-L) version
<a href="https://colab.research.google.com/drive/1mrOjUNFrfZ2vuTnWrfl9ebAQov3a9S6E?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
[![Huggingfaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/robustsam/robustsam/tree/main)
Official repository for RobustSAM: Segment Anything Robustly on Degraded Images
[Project Page](https://robustsam.github.io/) | [Paper](https://arxiv.org/abs/2406.09627) | [Dataset](https://huggingface.co/robustsam/robustsam/tree/main/dataset)
## Introduction
Segment Anything Model (SAM) has emerged as a transformative approach in image segmentation, acclaimed for its robust zero-shot segmentation capabilities and flexible prompting system. Nonetheless, its performance is challenged by images with degraded quality. Addressing this limitation, we propose the Robust Segment Anything Model (RobustSAM), which enhances SAM's performance on low-quality images while preserving its promptability and zero-shot generalization.
Our method leverages the pre-trained SAM model with only marginal parameter increments and computational requirements. The additional parameters of RobustSAM can be optimized within 30 hours on eight GPUs, demonstrating its feasibility and practicality for typical research laboratories. We also introduce the Robust-Seg dataset, a collection of 688K image-mask pairs with different degradations designed to train and evaluate our model optimally. Extensive experiments across various segmentation tasks and datasets confirm RobustSAM's superior performance, especially under zero-shot conditions, underscoring its potential for extensive real-world application. Additionally, our method has been shown to effectively improve the performance of SAM-based downstream tasks such as single image dehazing and deblurring.
**Disclaimer**: Content from **this** model card has been written by the Hugging Face team, and parts of it were copied from the original [SAM model card](https://github.com/facebookresearch/segment-anything).
# Model Details
The RobustSAM model is made up of 4 modules:
- The `VisionEncoder`: a ViT-based image encoder. It computes the image embeddings using attention on patches of the image. Relative positional embeddings are used.
- The `PromptEncoder`: generates embeddings for points and bounding boxes.
- The `MaskDecoder`: a two-way transformer that performs cross-attention between the image embeddings and the point embeddings (->) and between the point embeddings and the image embeddings. The outputs are fed to the `Neck`.
- The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.
# Usage
## Prompted-Mask-Generation
```python
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration
# load the RobustSAM model and processor
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-large")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-large")
# load an image from a url
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
# we define input points (2D localization of an object in the image)
input_points = [[[450, 600]]] # example point
```
```python
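import torch

# pick a device and move the model to it so that model and inputs match
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# process the image and input points
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
# generate masks using the model (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs)
# resize the predicted masks back to the original image size
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores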
```
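The snippet above returns several candidate masks per prompt together with their predicted IoU `scores`. The following is a minimal sketch of how the best candidate could be picked for each prompt. It uses synthetic NumPy arrays as stand-ins for the model outputs, and the array shapes noted in the docstring are assumptions rather than something this model card guarantees; `select_best_masks` is a hypothetical helper name.

```python
import numpy as np

def select_best_masks(masks, scores):
    """Return, for each prompt, the candidate mask with the highest predicted IoU.

    Assumed shapes (not confirmed by the model card):
      masks:  (num_prompts, num_candidates, H, W)
      scores: (num_prompts, num_candidates)
    """
    masks = np.asarray(masks)
    scores = np.asarray(scores)
    best = scores.argmax(axis=1)        # index of the best candidate per prompt
    rows = np.arange(scores.shape[0])   # one row index per prompt
    return masks[rows, best], scores[rows, best]

# synthetic stand-ins: 1 prompt, 3 candidate masks, 4x4 image
demo_masks = np.zeros((1, 3, 4, 4), dtype=bool)
demo_masks[0, 1, :2, :2] = True
demo_scores = np.array([[0.71, 0.93, 0.55]])

best_masks, best_scores = select_best_masks(demo_masks, demo_scores)
print(best_masks.shape, best_scores)
```

With the real outputs, and if the shapes match the assumptions above, a call such as `select_best_masks(masks[0].numpy(), scores[0].cpu().numpy())` would be the natural equivalent.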
Among other arguments to generate masks, you can pass 2D locations giving the approximate position of your object of interest, a bounding box wrapping the object of interest (given as the x, y coordinates of the top-left and bottom-right corners of the box), or a segmentation mask. At the time of writing, passing text as input is not supported by the official model, according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
For more details, refer to this notebook, which walks through how to use the model with a visual example!
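For the bounding-box prompt, corner order is easy to get wrong. The sketch below normalizes any two opposite corners into `(x_min, y_min, x_max, y_max)` order, that is, top-left followed by bottom-right in image coordinates. `make_box_prompt` is a hypothetical helper, the box coordinates are made up, and the `[image][box][4 coordinates]` nesting is an assumption based on how the SAM processor in `transformers` takes `input_boxes`.

```python
def make_box_prompt(corner_a, corner_b):
    """Build an `input_boxes`-style nested list from any two opposite corners.

    Coordinates are reordered to (x_min, y_min, x_max, y_max), i.e. top-left
    then bottom-right in image coordinates (y grows downward).
    """
    (xa, ya), (xb, yb) = corner_a, corner_b
    x_min, x_max = sorted((float(xa), float(xb)))
    y_min, y_max = sorted((float(ya), float(yb)))
    if x_min == x_max or y_min == y_max:
        raise ValueError("box must have non-zero width and height")
    # assumed nesting: [image][box][4 coordinates]
    return [[[x_min, y_min, x_max, y_max]]]

# hypothetical box, given here as its top-right and bottom-left corners
input_boxes = make_box_prompt((1000, 900), (650, 1250))
print(input_boxes)
```

The result could then be passed as `processor(raw_image, input_boxes=input_boxes, return_tensors="pt")`, assuming the processor loaded for this checkpoint accepts `input_boxes` in the same way as the SAM processor.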
## Automatic-Mask-Generation
The model can be used to generate segmentation masks in a "zero-shot" fashion, given an input image. The model is automatically prompted with a grid of `1024` points, which are all fed to the model.
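To make the `1024`-point grid concrete, here is a minimal sketch of how such a grid can be built and split into batches. It assumes a 32 × 32 layout with one point at the center of each grid cell, as in the original SAM automatic mask generator; the pipeline's internals may differ, and the image size and the helper names `build_point_grid` and `batched` are illustrative only.

```python
import numpy as np

def build_point_grid(points_per_side=32):
    """Evenly spaced (x, y) points in [0, 1] x [0, 1], one per grid-cell center."""
    offset = 1.0 / (2 * points_per_side)
    axis = np.linspace(offset, 1.0 - offset, points_per_side)
    xs, ys = np.meshgrid(axis, axis)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # shape (n * n, 2)

def batched(points, points_per_batch):
    """Yield consecutive chunks of at most `points_per_batch` points."""
    for start in range(0, len(points), points_per_batch):
        yield points[start:start + points_per_batch]

grid = build_point_grid(32)                      # 32 * 32 = 1024 point prompts
height, width = 480, 640                         # hypothetical image size
pixel_points = grid * np.array([width, height])  # scale to pixel coordinates
batches = list(batched(pixel_points, points_per_batch=256))
print(len(grid), len(batches), batches[0].shape)
```

A smaller `points_per_batch` means more, smaller batches, which trades speed for lower peak memory.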
The pipeline is made for automatic mask generation. The following snippet demonstrates how easily you can run it (on any device; simply pass an appropriate `points_per_batch` argument):
```python
from transformers import pipeline
# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-large", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
```
Now to display the generated mask on the image:
```python
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# simple function to display the mask
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # get the height and width from the mask
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# display the original image (`raw_image` is the PIL image loaded in the prompted example)
plt.imshow(np.array(raw_image))
ax = plt.gca()
# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
# show the image with the masks
plt.show()
```
## Visual Comparison
<table>
<tr>
<td>
<img src="figures/gif_output/blur_back_n_forth.gif" width="380">
</td>
<td>
<img src="figures/gif_output/haze_back_n_forth.gif" width="380">
</td>
</tr>
<tr>
<td>
<img src="figures/gif_output/lowlight_back_n_forth.gif" width="380">
</td>
<td>
<img src="figures/gif_output/rain_back_n_forth.gif" width="380">
</td>
</tr>
</table>
<img width="1096" alt="image" src='figures/qualitative_result.PNG'>
## Reference
If you find this work useful, please consider citing us!
```bibtex
@inproceedings{chen2024robustsam,
title={RobustSAM: Segment Anything Robustly on Degraded Images},
author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
journal={CVPR},
year={2024}
}
```
## Acknowledgements
We thank the authors of [SAM](https://github.com/facebookresearch/segment-anything), on which our repo is based.