|
---
library_name: transformers
language:
- vi
- en
base_model:
- 5CD-AI/Vintern-1B-v2
datasets:
- vidore/colpali_train_set
- 5CD-AI/Viet-Doc-VQA
- 5CD-AI/Viet-OCR-VQA
- 5CD-AI/Viet-Doc-VQA-II
tags:
- colpali
---
|
<div align="center"> |
|
<img src="colvintern.png" width="400"/> |
|
</div> |
|
|
|
## ColVintern-1B-v1 🇻🇳 ❄️ - A ColPali model for Vietnamese
|
|
|
**What's new in ColVintern-1B-v1!** |
|
- We implemented and trained the **ColPali pipeline** on top of **Vintern-1B-v2**. The model supports RAG by producing multi-vector embeddings for queries and for the document images that contain the relevant information (see the late-interaction sketch after this list).
- This first experimental version was trained on the English [**ColPali training set**](https://huggingface.co/datasets/vidore/colpali_train_set) plus about **5%** of our Vietnamese image-based question-answer pairs.
- The model performs nearly on par with ColPali v1 and has strong support for Vietnamese text, while using only 1B parameters versus the 2B-3B of current ColPali-style models.
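For readers new to ColPali-style retrieval: the score between a query and a page image comes from late interaction (MaxSim). Every query token embedding is compared to every image patch embedding, each token keeps its best match, and the per-token maxima are summed. Below is a minimal PyTorch sketch under assumed shapes; it is illustrative, not this repo's code.

```python
import torch
import torch.nn.functional as F

def maxsim_score(query_emb: torch.Tensor, image_emb: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) score between one query and one image.

    query_emb: (num_query_tokens, dim)  multi-vector query embedding
    image_emb: (num_image_patches, dim) multi-vector image embedding
    Vectors are assumed L2-normalized, as in ColPali-style models.
    """
    # Similarity between every query token and every image patch.
    sim = query_emb @ image_emb.T  # (num_query_tokens, num_image_patches)
    # Each query token keeps its best-matching patch; sum over tokens.
    return sim.max(dim=1).values.sum()

# Toy example with random embeddings (dim 128 is a typical projection size).
q = F.normalize(torch.randn(16, 128), dim=-1)
d = F.normalize(torch.randn(256, 128), dim=-1)
print(maxsim_score(q, d))
```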
|
|
|
## ColPali Benchmarks
|
|
|
We evaluated on the [**ViDoRe benchmark**](https://huggingface.co/collections/vidore/vidore-benchmark-667173f98e70a1c0fa4db00d) from the ColPali paper; scores are nDCG@5. The **TabF** and **Shift** test sets were excluded because they are in French; we plan to expand to more languages in the near future.
|
|
|
| Model | ArxivQ | DocQ | InfoQ | TATQ | AI | Energy | Gov. | Health. | Avg. |
|:-------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
| **Unstructured, text only** | | | | | | | | | |
| - BM25 | - | 34.1 | - | 44.0 | 90.4 | 78.3 | 78.8 | 82.6 | - |
| - BGE-M3 | - | 28.4 | - | 36.1 | 88.4 | 76.8 | 77.7 | 84.6 | - |
| **Unstructured + OCR** | | | | | | | | | |
| - BM25 | 31.6 | 36.8 | 62.9 | 62.7 | 92.8 | 85.9 | 83.9 | 87.2 | 68.0 |
| - BGE-M3 | 31.4 | 25.7 | 60.1 | 50.5 | 90.2 | 83.6 | 84.9 | 91.1 | 64.7 |
| **Unstructured + captioning** | | | | | | | | | |
| - BM25 | 40.1 | 38.4 | 70.0 | 61.5 | 88.0 | 84.7 | 82.7 | 89.2 | 69.3 |
| - BGE-M3 | 35.7 | 32.9 | 71.9 | 43.8 | 88.8 | 83.3 | 80.4 | 91.3 | 66.0 |
| **Contrastive VLMs** | | | | | | | | | |
| - Jina-CLIP | 25.4 | 11.9 | 35.5 | 3.3 | 15.2 | 19.7 | 21.4 | 20.8 | 19.2 |
| - Nomic-vision | 17.1 | 10.7 | 30.1 | 2.7 | 12.9 | 10.9 | 11.4 | 15.7 | 13.9 |
| - SigLIP (Vanilla) | 43.2 | 30.3 | 64.1 | 26.2 | 62.5 | 65.7 | 66.1 | 79.1 | 54.7 |
| **ColPali** | | | | | | | | | |
| - SigLIP (Vanilla) | 43.2 | 30.3 | 64.1 | 26.2 | 62.5 | 65.7 | 66.1 | 79.1 | 54.7 |
| - BiSigLIP (+fine-tuning) | 58.5 | 32.9 | 70.5 | 30.5 | 74.3 | 73.7 | 74.2 | 82.3 | 62.1 |
| - BiPali (+LLM) | 56.5 | 30.0 | 67.4 | 33.4 | 71.2 | 61.9 | 73.8 | 73.6 | 58.5 |
| - ColPali (+Late Inter.) | **79.1** | **54.4** | 81.8 | **65.8** | **96.2** | **91.0** | **92.7** | 94.4 | **81.3** |
| **Ours** | | | | | | | | | |
| - ColVintern-1B (+Late Inter.) | 71.6 | 48.3 | **84.6** | 59.6 | 92.9 | 88.7 | 89.4 | **95.2** | 78.8 |
|
|
|
For upcoming versions we are expanding the training set and refining the training recipe, e.g., mining hard negatives (sketched below) and training with larger GPU memory budgets, to push results further.
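To make the planned hard-negative mining concrete, here is a hedged sketch (our illustration, not the actual training code): score all query-image pairs with the current model, mask out each query's true page, and take the top-scoring wrong pages as that query's negatives.

```python
import torch

def mine_hard_negatives(scores: torch.Tensor, k: int = 4) -> torch.Tensor:
    """Pick the k highest-scoring *wrong* images for each query.

    scores: (num_queries, num_images) late-interaction scores from the
            current model, where image i is the positive for query i
            (a hypothetical aligned setup).
    Returns: (num_queries, k) indices of hard-negative images.
    """
    masked = scores.clone()
    # Exclude each query's positive image from the candidate pool.
    masked.fill_diagonal_(float("-inf"))
    return masked.topk(k, dim=1).indices

# Toy usage: 8 queries x 8 candidate images, 2 hard negatives per query.
scores = torch.randn(8, 8)
print(mine_hard_negatives(scores, k=2))
```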
|
|
|
## Examples |
|
|
|
Input Images: |
|
<div style="display: flex; gap: 20px;"> |
|
<img src="ex1.jpg" width="300"/> |
|
<img src="ex2.jpg" width="300"/> |
|
</div> |
|
|
|
Input Queries: |
|
```python
queries = [
    "Cảng Hải Phòng thông báo gì ?",  # "What is Hai Phong Port announcing?"
    "Phí giao hàng bao nhiêu ?",      # "How much is the delivery fee?"
]
```
|
|
|
Output Scores: |
|
| Query | Image 1 score | Image 2 score |
|-------------------------------|:-------------:|:-------------:|
| Cảng Hải Phòng thông báo gì ? | 62.4333 | 59.9523 |
| Phí giao hàng bao nhiêu ? | 60.7748 | 62.8654 |
|
|
|
|
|
|
|
## Quickstart
|
|
|
Colab: https://colab.research.google.com/drive/1-y8HLiyS0oCj7Vpy4i7FsJ1A6kU7ROca?usp=sharing |
|
|
|
```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
import matplotlib.pyplot as plt

model_name = "5CD-AI/ColVintern-1B-v1"

# The processor provides ColPali-style helpers: process_images / process_queries.
processor = AutoProcessor.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval().cuda()

# Example documents:
# !wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex1.jpg
# !wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex2.jpg
images = [Image.open("ex1.jpg"), Image.open("ex2.jpg")]
batch_images = processor.process_images(images)

queries = [
    "Cảng Hải Phòng thông báo gì ?",  # "What is Hai Phong Port announcing?"
    "Phí giao hàng bao nhiêu ?",      # "How much is the delivery fee?"
]
batch_queries = processor.process_queries(queries)

# Move inputs to the GPU; cast floating-point tensors to the model's bfloat16 dtype.
batch_images["pixel_values"] = batch_images["pixel_values"].cuda().bfloat16()
batch_images["input_ids"] = batch_images["input_ids"].cuda()
batch_images["attention_mask"] = batch_images["attention_mask"].cuda().bfloat16()
batch_queries["input_ids"] = batch_queries["input_ids"].cuda()
batch_queries["attention_mask"] = batch_queries["attention_mask"].cuda().bfloat16()

# The forward pass returns multi-vector embeddings (one vector per token/patch).
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

# Late-interaction (MaxSim) scores, shape (num_queries, num_images).
scores = processor.score_multi_vector(query_embeddings, image_embeddings)

# Show the best-matching image for each query.
max_scores, max_indices = torch.max(scores, dim=1)
for i, query in enumerate(queries):
    print(f"Query: '{query}'")
    print(f"Score: {max_scores[i].item()}\n")
    plt.figure(figsize=(5, 5))
    plt.imshow(images[max_indices[i]])
    plt.show()
```
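Beyond the two-image demo, the same API scales to a small corpus. The sketch below is a hypothetical extension, not code from this repo: the `pages/` folder, the batch size, and the assumption that `processor.score_multi_vector` accepts a list of per-image embedding tensors (as in the colpali-engine API) are all ours.

```python
# Hypothetical extension of the quickstart: rank a folder of page images for
# each query and print the top-k pages. Reuses `processor`, `model`, `queries`,
# and `query_embeddings` from the snippet above.
import glob

corpus_paths = sorted(glob.glob("pages/*.jpg"))  # assumed local folder of page images
corpus_embeddings = []

# Embed the corpus in small batches to bound GPU memory (batch size is illustrative).
for start in range(0, len(corpus_paths), 4):
    batch = processor.process_images(
        [Image.open(p) for p in corpus_paths[start:start + 4]]
    )
    batch["pixel_values"] = batch["pixel_values"].cuda().bfloat16()
    batch["input_ids"] = batch["input_ids"].cuda()
    batch["attention_mask"] = batch["attention_mask"].cuda().bfloat16()
    with torch.no_grad():
        # One multi-vector embedding per image in the batch.
        corpus_embeddings.extend(list(model(**batch)))

# Score every query against every page and keep the top 3.
scores = processor.score_multi_vector(query_embeddings, corpus_embeddings)
topk = scores.topk(min(3, len(corpus_paths)), dim=1)
for qi, query in enumerate(queries):
    print(f"Query: '{query}'")
    for rank, (idx, s) in enumerate(zip(topk.indices[qi], topk.values[qi]), 1):
        print(f"  {rank}. {corpus_paths[idx]} (score {s.item():.2f})")
```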
|
|
|