---
license: apache-2.0
datasets:
- test-tcy/VisRAG-Ret-Train-In-domain-data
- test-tcy/VisRAG-Ret-Train-Synthetic-data
language:
- en
base_model:
- openbmb/MiniCPM-V-2
tags:
- VisRAG
pipeline_tag: feature-extraction
---

# VisRAG: Vision-based Retrieval-augmented Generation on Multi-modality Documents

**VisRAG** is a novel vision-language model (VLM)-based RAG pipeline. Instead of first parsing a document to obtain text, this pipeline uses a VLM to embed the document directly as an image, retrieves it, and uses it to enhance the generation of a VLM. Compared to traditional text-based RAG, **VisRAG** maximizes the retention and utilization of the information in the original documents, eliminating the information loss introduced during parsing.

<p align="center"><img width=800 src="https://github.com/openbmb/VisRAG/blob/master/assets/main_figure.png?raw=true"/></p>
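
As a rough illustration of the retrieval step, the sketch below ranks document pages by the dot product of L2-normalized embeddings and keeps the top k. The 2-D vectors are toy stand-ins for VisRAG-Ret outputs, and `retrieve_top_k` is a hypothetical helper, not part of any released API. In VisRAG, the selected page images (rather than parsed text) would then be passed to the generator VLM.

```python
import numpy as np


def retrieve_top_k(query_emb, doc_embs, k=3):
    """Hypothetical helper: return indices and scores of the k most similar pages.

    Assumes all embeddings are L2-normalized, so the dot product equals cosine similarity.
    """
    scores = doc_embs @ query_emb          # one score per page, shape (num_pages,)
    top = np.argsort(-scores)[:k]          # indices sorted by descending score
    return top.tolist(), scores[top].tolist()


# Toy, hand-made "embeddings" standing in for VisRAG-Ret outputs (already unit length).
pages = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]])
query = np.array([0.0, 1.0])

idx, top_scores = retrieve_top_k(query, pages, k=2)
print(idx)  # [2, 1]
```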

## VisRAG Description

### VisRAG-Ret

**VisRAG-Ret** is a document embedding model built on [MiniCPM-V 2.0](https://huggingface.co/openbmb/MiniCPM-V-2), a vision-language model that integrates [SigLIP](https://huggingface.co/google/siglip-so400m-patch14-384) as the vision encoder and [MiniCPM-2B](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) as the language model.

### VisRAG-Gen

In the paper, we use MiniCPM-V 2.0, MiniCPM-V 2.6, and GPT-4o as generators. In practice, you can use any VLM you like.

## Training

### VisRAG-Ret

The training dataset for **VisRAG-Ret** contains 362,110 query-document (Q-D) pairs. It combines the training sets of openly available academic datasets (34%) with a synthetic dataset (66%) made up of pages from web-crawled PDF documents, augmented with pseudo-queries generated by a VLM (GPT-4o).
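
For a sense of scale, the reported shares translate into the approximate pair counts below. The percentages are rounded in the source, so these derived counts are estimates, not official figures.

```python
TOTAL_PAIRS = 362_110

# Shares as reported above; they are rounded, so the derived counts are approximate.
shares = {"academic": 0.34, "synthetic": 0.66}
approx = {name: round(TOTAL_PAIRS * share) for name, share in shares.items()}
print(approx)  # {'academic': 123117, 'synthetic': 238993}
```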

### VisRAG-Gen

The generation stage involves no fine-tuning; we use off-the-shelf LLMs/VLMs directly for generation.

## Implementation Details

**VisRAG-Ret** is fine-tuned using [in-batch negatives](https://arxiv.org/abs/2004.04906) for one epoch with a batch size of 128 on 8 NVIDIA A100 80GB GPUs. The temperature of the contrastive loss is set to 0.02.
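
A minimal sketch of the standard in-batch-negatives contrastive (InfoNCE) objective with temperature 0.02 is shown below. It assumes a one-directional query-to-document loss in which document `i` is the positive for query `i` and all other documents in the batch act as negatives. NumPy stands in for the actual training framework, and `in_batch_infonce` is a hypothetical name; the exact VisRAG-Ret training code may differ.

```python
import numpy as np


def in_batch_infonce(q, d, temperature=0.02):
    """In-batch-negatives contrastive loss (hypothetical helper).

    q, d: arrays of shape (batch, dim), assumed L2-normalized.
    Row i of `d` is the positive for query i; every other row is a negative.
    """
    logits = (q @ d.T) / temperature                      # (batch, batch) similarities
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positives lie on the diagonal; average their negative log-likelihood.
    return float(-np.mean(np.diag(log_probs)))


rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
q /= np.linalg.norm(q, axis=1, keepdims=True)
d_aligned = q.copy()                   # each document matches its own query
d_shuffled = d_aligned[[1, 2, 3, 0]]   # positives moved off the diagonal

# Correctly paired batches should yield a much lower loss.
print(in_batch_infonce(q, d_aligned) < in_batch_infonce(q, d_shuffled))  # True
```

The small temperature sharpens the softmax over the batch, so even modest similarity gaps between the positive and the negatives produce large differences in the loss.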

## Requirements

```
torch==2.1.2
torchvision==0.16.2
transformers==4.40.2
sentencepiece==0.1.99
decord==0.6.0
Pillow==10.1.0
accelerate==0.27.0
deepspeed==0.13.2
protobuf==4.25.0
pytrec_eval==0.5
```

## Usage

### VisRAG-Ret

```python
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import os


def weighted_mean_pooling(hidden, attention_mask):
    # Weight each non-padding token by its running count of non-padding tokens,
    # so later tokens count more; padding positions get weight 0.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps


@torch.no_grad()
def encode(text_or_image_list):
    # Queries are encoded as text without images; documents as images with empty text.
    if isinstance(text_or_image_list[0], str):
        inputs = {
            "text": text_or_image_list,
            "image": [None] * len(text_or_image_list),
            "tokenizer": tokenizer,
        }
    else:
        inputs = {
            "text": [""] * len(text_or_image_list),
            "image": text_or_image_list,
            "tokenizer": tokenizer,
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    reps = weighted_mean_pooling(hidden, attention_mask)
    # L2-normalize so that the dot product below equals cosine similarity.
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings


tokenizer = AutoTokenizer.from_pretrained("openbmb/VisRAG", trust_remote_code=True)
model = AutoModel.from_pretrained("openbmb/VisRAG", trust_remote_code=True)
model.eval()

# The example images are expected in a `test_image/` folder next to this script.
script_dir = os.path.dirname(os.path.realpath(__file__))
queries = ["What does a dog look like?"]
passages = [
    Image.open(os.path.join(script_dir, "test_image/cat.jpeg")).convert("RGB"),
    Image.open(os.path.join(script_dir, "test_image/dog.jpg")).convert("RGB"),
]

# Keep this prefix exactly as written (including the spelling "relavant");
# changing it changes the query embeddings.
INSTRUCTION = "Represent this query for retrieving relavant documents: "
queries = [INSTRUCTION + query for query in queries]

embeddings_query = encode(queries)
embeddings_doc = encode(passages)

# Similarity matrix of shape (num_queries, num_passages).
scores = embeddings_query @ embeddings_doc.T
print(scores.tolist())
```
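
To see what the `weighted_mean_pooling` helper in the example computes, the small NumPy illustration below reproduces its arithmetic on a made-up sequence of four tokens whose last position is padding. NumPy stands in for PyTorch here, and the mask and hidden values are toy data: the weights become 1, 2, 3, 0 (normalized to 1/6, 2/6, 3/6, 0), so later tokens contribute more and padding contributes nothing.

```python
import numpy as np

# One sequence of 4 tokens; the last position is padding (mask = 0).
attention_mask = np.array([[1, 1, 1, 0]])
weights = attention_mask * attention_mask.cumsum(axis=1)   # [[1, 2, 3, 0]]
weights = weights / weights.sum(axis=1, keepdims=True)      # [[1/6, 2/6, 3/6, 0]]

# Toy hidden states of shape (batch=1, seq=4, dim=1); the padded value is ignored.
hidden = np.array([[[6.0], [0.0], [2.0], [100.0]]])
pooled = (hidden * weights[..., None]).sum(axis=1)
print(pooled)  # [[2.]]
```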

## License

* The code in this repo is released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
* The usage of **VisRAG-Ret** model weights must strictly follow [MiniCPM Model License.md](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md).
* The models and weights of **VisRAG-Ret** are completely free for academic research. After registering via this [questionnaire](https://modelbest.feishu.cn/share/base/form/shrcnpV5ZT9EJ6xYjh3Kx0J6v8g), **VisRAG-Ret** weights are also available for free commercial use.

## Contact

- Shi Yu: [email protected]
- Chaoyue Tang: [email protected]

## Citation

If you use any datasets or models from this organization in your research, please cite the original dataset as follows: