LISA: Reasoning Segmentation via Large Language Model

LISA: Large Language Instructed Segmentation Assistant

Input: "Who was the president of the US in this image? Please output segmentation mask and explain the reason."
Output: "Sure, the segmentation result is [SEG]. The President of the United States in the image is President Obama."

Input: "Who was the president of the US in this image? Please output segmentation mask and explain why."
Output: "Sure, [SEG]. In the image, the President of the United States is President Trump."

Input: "What can make the woman stand higher? Please output segmentation mask and explain why."
Output: "Sure, [SEG]. The woman is standing higher by using a ladder..."

Input: "Can you segment the camera lens that is more suitable for photographing nearby objects in this image?"
Output: "Sure, [SEG]."

Input: "Can you segment the unusual part in this image and explain why."
Output: "Sure, it is [SEG]. In the image, the unusual part is the dog wearing a reindeer antler headband..."

Input: "Where to wash hands in this image? Please output segmentation mask."
Output: "Sure, [SEG]."

Input: "Can you segment the founder of Alibaba in this image and explain why?"
Output: "Sure, it is [SEG]. In the image, the man wearing a red tie and sitting on one of the chairs is Jack Ma, the co-founder of Alibaba Group..."

Input: "Please segment Lisa in this figure."
Output: "Sure, [SEG]."

LISA: Reasoning Segmentation Via Large Language Model [Paper]
Xin Lai, Zhuotao Tian, Yukang Chen, Yanwei Li, Yuhui Yuan, Shu Liu, Jiaya Jia

Abstract

In this work, we propose a new segmentation task: reasoning segmentation. The task is designed to output a segmentation mask given a complex and implicit query text. We establish a benchmark comprising over one thousand image-instruction pairs, incorporating intricate reasoning and world knowledge for evaluation purposes. Finally, we present LISA: Large-language Instructed Segmentation Assistant, which inherits the language generation capabilities of the multi-modal Large Language Model (LLM) while also possessing the ability to produce segmentation masks. For more details, please refer to the paper.

Highlights

LISA unlocks the new segmentation capabilities of multi-modal LLMs, and can handle cases involving:

  1. complex reasoning;
  2. world knowledge;
  3. explanatory answers;
  4. multi-turn conversation.

LISA also demonstrates robust zero-shot capability when trained exclusively on reasoning-free datasets. In addition, fine-tuning the model with merely 239 reasoning segmentation image-instruction pairs results in further performance enhancement.

Installation

pip install -r requirements.txt
pip install flash-attn --no-build-isolation

Training

Training Data Preparation

The training data consists of 4 types of data:

  1. Semantic segmentation datasets: ADE20K, COCO-Stuff, Mapillary, PACO-LVIS, PASCAL-Part

  2. Referring segmentation datasets: refCOCO, refCOCO+, refCOCOg, refCLEF (saiapr_tc-12)

    Note: the original links for the refCOCO series data are down, so we have replaced them with new ones.

  3. Visual Question Answering dataset: LLaVA-Instruct-150k

  4. Reasoning segmentation dataset: ReasonSeg

Download them from the above links, and organize them as follows.

├── dataset
│   ├── ade20k
│   │   ├── annotations
│   │   └── images
│   ├── coco
│   │   └── train2017
│   ├── cocostuff
│   │   ├── annotations
│   │   └── train2017
│   ├── llava_dataset
│   │   └── llava_instruct_150k.json
│   ├── mapillary
│   │   ├── config_v2.0.json
│   │   ├── testing
│   │   ├── training
│   │   └── validation
│   ├── reason_seg
│   │   └── ReasonSeg
│   │       ├── train
│   │       ├── val
│   │       └── explanatory
│   ├── refer_seg
│   │   ├── images
│   │   │   ├── saiapr_tc-12
│   │   │   └── mscoco
│   │   │       └── images
│   │   │           └── train2014
│   │   ├── refclef
│   │   ├── refcoco
│   │   ├── refcoco+
│   │   └── refcocog
│   └── vlpart
│       ├── paco
│       │   └── annotations
│       └── pascal_part
│           ├── train.json
│           └── VOCdevkit
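
Before launching training, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: the helper name `find_missing` is hypothetical, the path list is copied from the tree above, and `./dataset` is assumed to be the dataset root.

```python
from pathlib import Path

# Relative paths copied from the directory tree above.
EXPECTED_PATHS = [
    "ade20k/annotations",
    "ade20k/images",
    "coco/train2017",
    "cocostuff/annotations",
    "cocostuff/train2017",
    "llava_dataset/llava_instruct_150k.json",
    "mapillary/config_v2.0.json",
    "mapillary/testing",
    "mapillary/training",
    "mapillary/validation",
    "reason_seg/ReasonSeg/train",
    "reason_seg/ReasonSeg/val",
    "reason_seg/ReasonSeg/explanatory",
    "refer_seg/images/saiapr_tc-12",
    "refer_seg/images/mscoco/images/train2014",
    "refer_seg/refclef",
    "refer_seg/refcoco",
    "refer_seg/refcoco+",
    "refer_seg/refcocog",
    "vlpart/paco/annotations",
    "vlpart/pascal_part/train.json",
    "vlpart/pascal_part/VOCdevkit",
]


def find_missing(dataset_dir):
    """Return the expected relative paths that do not exist under dataset_dir."""
    root = Path(dataset_dir)
    return [rel for rel in EXPECTED_PATHS if not (root / rel).exists()]


if __name__ == "__main__":
    missing = find_missing("./dataset")  # assumed dataset root
    if missing:
        print("Missing entries:")
        for rel in missing:
            print("  " + rel)
    else:
        print("Dataset layout looks complete.")
```

The check only tests that each path exists; it does not validate the contents of any dataset.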

Pre-trained weights

LLaVA

To train LISA-7B or LISA-13B, follow the instructions to merge the LLaVA delta weights. We typically use the final weights LLaVA-Lightning-7B-v1-1 and LLaVA-13B-v1-1, merged from liuhaotian/LLaVA-Lightning-7B-delta-v1-1 and liuhaotian/LLaVA-13b-delta-v1-1, respectively. For Llama2, the LLaVA full weights liuhaotian/llava-llama-2-13b-chat-lightning-preview can be used directly.

SAM ViT-H weights

Download SAM ViT-H pre-trained weights from the link.

Training

deepspeed --master_port=24999 train_ds.py \
  --version="PATH_TO_LLaVA" \
  --dataset_dir='./dataset' \
  --vision_pretrained="PATH_TO_SAM" \
  --dataset="sem_seg||refer_seg||vqa||reason_seg" \
  --sample_rates="9,3,3,1" \
  --exp_name="lisa-7b"
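
The `--dataset` and `--sample_rates` arguments are given in matching order. The sketch below illustrates one plausible reading, assuming the rates are relative weights that get normalized into per-dataset sampling probabilities; the function name `sampling_probabilities` is hypothetical, and the authoritative behavior is whatever `train_ds.py` and the dataset code implement.

```python
def sampling_probabilities(datasets, sample_rates):
    """Map each dataset name to its normalized sampling probability.

    Assumption: rates are relative weights, so "9,3,3,1" means 9/16, 3/16, ...
    """
    names = datasets.split("||")
    rates = [float(r) for r in sample_rates.split(",")]
    if len(names) != len(rates):
        raise ValueError("each dataset needs exactly one sample rate")
    total = sum(rates)
    return {name: rate / total for name, rate in zip(names, rates)}


probs = sampling_probabilities("sem_seg||refer_seg||vqa||reason_seg", "9,3,3,1")
for name, p in probs.items():
    print(f"{name}: {p:.4f}")
# sem_seg: 0.5625
# refer_seg: 0.1875
# vqa: 0.1875
# reason_seg: 0.0625
```

Under this reading, semantic segmentation data would be drawn a little more than half of the time, and ReasonSeg about one time in sixteen.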

When training is finished, convert the DeepSpeed checkpoint into the full model weights:

cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin

Validation

deepspeed --master_port=24999 train_ds.py \
  --version="PATH_TO_LLaVA" \
  --dataset_dir='./dataset' \
  --vision_pretrained="PATH_TO_SAM" \
  --exp_name="lisa-7b" \
  --weight='PATH_TO_pytorch_model.bin' \
  --eval_only

Inference

To chat with LISA-13B-llama2-v0 or LISA-13B-llama2-v0-explanatory: (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)

CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0'
CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0-explanatory'

To use bf16 or fp16 data type for inference:

CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0' --precision='bf16'

To use 8bit or 4bit data type for inference (this enables running the 13B model on a single 24G or 12G GPU, at some cost in generation quality):

CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0' --precision='fp16' --load_in_8bit
CUDA_VISIBLE_DEVICES=0 python3 chat.py --version='xinlai/LISA-13B-llama2-v0' --precision='fp16' --load_in_4bit
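
A rough back-of-the-envelope estimate shows why lower precision matters for GPU memory. This sketch counts weight storage only and assumes a nominal 13e9 parameters for the language model; it ignores the vision encoders, activations, the KV cache, and quantization overhead, so real usage will be higher. The helper `weight_memory_gb` is hypothetical.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4.0, "bf16": 2.0, "fp16": 2.0, "8bit": 1.0, "4bit": 0.5}


def weight_memory_gb(num_params, precision):
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9


for precision in ("fp32", "fp16", "8bit", "4bit"):
    # 13e9 is an assumed nominal parameter count for the 13B model.
    print(f"{precision}: ~{weight_memory_gb(13e9, precision):.1f} GB")
# fp32: ~52.0 GB
# fp16: ~26.0 GB
# 8bit: ~13.0 GB
# 4bit: ~6.5 GB
```

These figures are consistent with the 8bit model fitting on a 24G GPU and the 4bit model fitting on a 12G GPU, with the remaining memory left for everything the estimate leaves out.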

After that, input the text prompt and then the image path. For example,

- Please input your prompt: Where can the driver see the car speed in this image? Please output segmentation mask.
- Please input the image path: imgs/example1.jpg

- Please input your prompt: Can you segment the food that tastes spicy and hot?
- Please input the image path: imgs/example2.jpg

Dataset

In ReasonSeg, we have collected 1218 images (239 train, 200 val, and 779 test). The training and validation sets can be downloaded from this link.

Each image is provided with an annotation JSON file:

image_1.jpg, image_1.json
image_2.jpg, image_2.json
...
image_n.jpg, image_n.json
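
The pairing above can be enumerated with a few lines of Python. This is a minimal sketch, assuming every image is a `.jpg` file whose annotation shares its base name; `list_annotated_images` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def list_annotated_images(split_dir):
    """Return sorted (image_path, json_path) pairs found in split_dir.

    Images without a matching JSON file are skipped.
    """
    pairs = []
    for image_path in sorted(Path(split_dir).glob("*.jpg")):
        json_path = image_path.with_suffix(".json")
        if json_path.exists():
            pairs.append((image_path, json_path))
    return pairs


if __name__ == "__main__":
    # Path taken from the dataset layout described earlier in this README.
    pairs = list_annotated_images("./dataset/reason_seg/ReasonSeg/train")
    print(f"Found {len(pairs)} annotated images")
```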

Important keys contained in JSON files:

- "text": text instructions.
- "is_sentence": whether the text instructions are long sentences.
- "shapes": target polygons.

The elements of "shapes" fall into two categories, "target" and "ignore". The "target" shapes are required for evaluation, while the "ignore" shapes mark ambiguous regions and are disregarded during evaluation.

We provide a script that demonstrates how to process the annotations:

python3 utils/data_processing.py
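
Independently of that script, the split between "target" and "ignore" shapes can be sketched as follows. The field names `label` and `points` inside each shape are assumptions (a LabelMe-style layout); check them against your JSON files and `utils/data_processing.py` before relying on them. Both helpers are hypothetical.

```python
import json


def load_annotation(json_path):
    """Read one ReasonSeg annotation file into a dict."""
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def split_shapes(annotation):
    """Separate polygons into (targets, ignores) lists of point lists."""
    targets, ignores = [], []
    for shape in annotation["shapes"]:
        label = shape["label"]    # assumed field name
        points = shape["points"]  # assumed field name: [[x, y], ...]
        if label == "target":
            targets.append(points)
        elif label == "ignore":
            ignores.append(points)
    return targets, ignores


# Hypothetical in-memory annotation, for illustration only.
example = {
    "text": ["Where can the driver see the car speed?"],
    "is_sentence": True,
    "shapes": [
        {"label": "target", "points": [[10, 10], [50, 10], [50, 40]]},
        {"label": "ignore", "points": [[60, 60], [80, 60], [80, 90]]},
    ],
}
targets, ignores = split_shapes(example)
print(len(targets), "target polygon(s),", len(ignores), "ignore polygon(s)")
```

Pixels inside the "ignore" polygons would then be excluded when computing evaluation metrics.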

In addition, we used GPT-3.5 to rephrase instructions, so an image in the training set may have more than one instruction (but fewer than six) in the "text" field. During training, users may randomly select one of them as the text query to obtain a better model.
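
Random selection of a text query can be sketched as below. This assumes the "text" field holds either a list of instruction strings or a single string; `pick_instruction` is a hypothetical helper, not part of the repository.

```python
import random


def pick_instruction(annotation, rng=random):
    """Return one instruction chosen uniformly at random from annotation["text"]."""
    text = annotation["text"]
    # Accept a single string as well as a list of rephrasings (assumption).
    candidates = [text] if isinstance(text, str) else list(text)
    if not candidates:
        raise ValueError("annotation has no text instructions")
    return rng.choice(candidates)


# Hypothetical annotation with two rephrasings of the same query.
annotation = {
    "text": [
        "Can you segment the food that tastes spicy and hot?",
        "Please segment the spicy food in this image.",
    ]
}
rng = random.Random(0)  # seeded only to make the example reproducible
print(pick_instruction(annotation, rng))
```

Calling the helper once per training sample, rather than once per image, exposes the model to all rephrasings over the course of training.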

Citation

If you find this project useful in your research, please consider citing:

@article{reason_seg,
  title={LISA: Reasoning Segmentation via Large Language Model},
  author={Xin Lai and Zhuotao Tian and Yukang Chen and Yanwei Li and Yuhui Yuan and Shu Liu and Jiaya Jia},
  journal={arXiv:2308.00692},
  year={2023}
}

Acknowledgement

  • This work is built upon LLaVA and SAM.