---
library_name: transformers
language:
- vi
- en
base_model:
- 5CD-AI/Vintern-1B-v2
datasets:
- vidore/colpali_train_set
- 5CD-AI/Viet-Doc-VQA
- 5CD-AI/Viet-OCR-VQA
- 5CD-AI/Viet-Doc-VQA-II
tags:
- colpali
---
<div align="center">
<img src="colvintern.png" width="400"/>
</div>

## ColVintern-1B-v1 🇻🇳 ❄️ - ColPali version for Vietnamese

**What's new in ColVintern-1B-v1!**
- We implemented and successfully trained the **ColPali pipeline** for **Vintern-1B-v2**. The model supports RAG by producing embedding vectors for questions and for document images, so that the images containing the relevant information can be retrieved.
- This is the first experimental version, trained on the [**ColPali dataset**](https://huggingface.co/datasets/vidore/colpali_train_set) for English and on **5%** of our Vietnamese image-based question-answer pairs.
- The model achieves results close to ColPali version 1, with strong support for Vietnamese text and only 1 billion parameters, compared with the 2B-3B parameters of current ColPali models.

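The retrieval step described above compares a question with page images through multi-vector embeddings. ColPali-style models score each pair by late interaction (MaxSim): every query-token vector is matched with its most similar image-token vector, and the maxima are summed. The following is a minimal sketch of that scoring rule on toy tensors, assuming ColVintern follows the same scheme. `late_interaction_scores` is an illustrative name, not part of the model's API, and padding masks are ignored for brevity.

```python
import torch


def late_interaction_scores(query_emb: torch.Tensor, image_emb: torch.Tensor) -> torch.Tensor:
    """Illustrative MaxSim scoring (not the released implementation).

    query_emb: (num_queries, query_tokens, dim)
    image_emb: (num_images, image_tokens, dim)
    returns:   (num_queries, num_images)
    """
    # Dot product between every query token and every image token.
    sim = torch.einsum("qnd,imd->qinm", query_emb, image_emb)
    # Keep the best-matching image token per query token, then sum over query tokens.
    return sim.max(dim=-1).values.sum(dim=-1)


# Toy data: 1 query with 2 token vectors, 2 images with 3 token vectors each.
query_emb = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
image_emb = torch.tensor([
    [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],  # matches both query tokens
    [[0.5, 0.0], [0.0, 0.0], [0.0, 0.0]],  # matches only one, weakly
])

scores = late_interaction_scores(query_emb, image_emb)
print(scores)  # one row per query, one column per image
```

Because the score is a sum over query tokens, longer queries tend to produce larger raw values, so scores are best compared across images for the same query.
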
## ColPali Benchmarks

We evaluated the model on the [**ViDoRe benchmark**](https://huggingface.co/collections/vidore/vidore-benchmark-667173f98e70a1c0fa4db00d) from the ColPali paper. The **TabF** and **Shift** test sets were excluded because they are in French. We plan to add support for more languages in the near future.

| | ArxivQ | DocQ | InfoQ | TATQ | AI | Energy | Gov. | Health. | Avg. |
|:------------------------------:|:--------:|:------:|:-------:|:------:|:------:|:--------:|:-------:|:---------:|:--------:|
| **Unstructured** Text only | | | | | | | | | |
| - BM25 | - | 34.1 | - | 44.0 | 90.4 | 78.3 | 78.8 | 82.6 | - |
| - BGE-M3 | - | 28.4 | - | 36.1 | 88.4 | 76.8 | 77.7 | 84.6 | - |
| **Unstructured** + OCR | | | | | | | | | |
| - BM25 | 31.6 | 36.8 | 62.9 | 62.7 | 92.8 | 85.9 | 83.9 | 87.2 | 68.0 |
| - BGE-M3 | 31.4 | 25.7 | 60.1 | 50.5 | 90.2 | 83.6 | 84.9 | 91.1 | 64.7 |
| **Unstructured** + Captioning | | | | | | | | | |
| - BM25 | 40.1 | 38.4 | 70.0 | 61.5 | 88.0 | 84.7 | 82.7 | 89.2 | 69.3 |
| - BGE-M3 | 35.7 | 32.9 | 71.9 | 43.8 | 88.8 | 83.3 | 80.4 | 91.3 | 66.0 |
| **Contrastive VLMs** | | | | | | | | | |
| - Jina-CLIP | 25.4 | 11.9 | 35.5 | 3.3 | 15.2 | 19.7 | 21.4 | 20.8 | 19.2 |
| - Nomic-vision | 17.1 | 10.7 | 30.1 | 2.7 | 12.9 | 10.9 | 11.4 | 15.7 | 13.9 |
| - SigLIP (Vanilla) | 43.2 | 30.3 | 64.1 | 26.2 | 62.5 | 65.7 | 66.1 | 79.1 | 54.7 |
| **ColPali** | | | | | | | | | |
| - SigLIP (Vanilla) | 43.2 | 30.3 | 64.1 | 26.2 | 62.5 | 65.7 | 66.1 | 79.1 | 54.7 |
| - BiSigLIP (+fine-tuning) | 58.5 | 32.9 | 70.5 | 30.5 | 74.3 | 73.7 | 74.2 | 82.3 | 62.1 |
| - BiPali (+LLM) | 56.5 | 30.0 | 67.4 | 33.4 | 71.2 | 61.9 | 73.8 | 73.6 | 58.5 |
| - ColPali (+Late Inter.) | **79.1** | **54.4** | 81.8 | **65.8** | **96.2** | **91.0** | **92.7** | 94.4 | **81.9** |
| **Ours** | | | | | | | | | |
| - ColVintern-1B (+Late Inter.) | 71.6 | 48.3 | **84.6** | 59.6 | 92.9 | 88.7 | 89.4 | **95.2** | 78.8 |

For upcoming versions we are expanding the training dataset, adding hard-negative mining, and training with more GPU VRAM, among other changes, to achieve better results.

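The comparison with ColPali can be checked directly from the two late-interaction rows of the table above. The short sketch below copies the eight per-dataset scores by hand, recomputes the ColVintern-1B average, and lists the per-dataset gap. The variable names are illustrative only.

```python
datasets = ["ArxivQ", "DocQ", "InfoQ", "TATQ", "AI", "Energy", "Gov.", "Health."]

# Per-dataset scores copied from the benchmark table.
colpali = [79.1, 54.4, 81.8, 65.8, 96.2, 91.0, 92.7, 94.4]
colvintern = [71.6, 48.3, 84.6, 59.6, 92.9, 88.7, 89.4, 95.2]

# Mean over the eight English datasets used in this evaluation.
colvintern_avg = sum(colvintern) / len(colvintern)

# Positive gap means ColVintern-1B scores higher than ColPali on that dataset.
gaps = {
    name: round(ours - ref, 1)
    for name, ours, ref in zip(datasets, colvintern, colpali)
}
wins = [name for name, gap in gaps.items() if gap > 0]

print(f"ColVintern-1B average: {colvintern_avg:.1f}")
for name, gap in gaps.items():
    print(f"{name:8s} {gap:+.1f}")
print("ColVintern-1B ahead on:", ", ".join(wins))
```
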
## Examples

Input Images:

<div style="display: flex; gap: 20px;">
<img src="ex1.jpg" width="300"/>
<img src="ex2.jpg" width="300"/>
</div>

Input Queries:

```
queries = ["Cảng Hải Phòng thông báo gì ?", "Phí giao hàng bao nhiêu ?"]
```

Output Scores:

| Query | Image 1 Score | Image 2 Score |
|--------------------------------------|---------------|---------------|
| Chuyện gì xảy ra với quốc lộ 5 TP Hải Phòng ? | 62.4333 | 59.9523 |
| Phí giao hàng bao nhiêu ? | 60.7748 | 62.8654 |

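Each row of the score table is read across: the image with the highest score in a row is the one retrieved for that query. The sketch below, using the table's values copied by hand, picks the best image per query and reports the margin over the other image. It is only an illustration of how to interpret the scores, not part of the model's code.

```python
queries = [
    "Chuyện gì xảy ra với quốc lộ 5 TP Hải Phòng ?",
    "Phí giao hàng bao nhiêu ?",
]

# One row per query, one column per image (values from the table above).
scores = [
    [62.4333, 59.9523],
    [60.7748, 62.8654],
]

best_images = []
for query, row in zip(queries, scores):
    # Index of the highest-scoring image for this query.
    best = max(range(len(row)), key=lambda j: row[j])
    # Margin between the best image and the strongest alternative.
    runner_up = max(s for j, s in enumerate(row) if j != best)
    margin = row[best] - runner_up
    best_images.append(best)
    print(f"{query!r} -> Image {best + 1} (margin {margin:.4f})")
```

The raw values are close to each other, so the ranking within a row matters more than the absolute numbers.
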
## Quickstart

Colab: https://colab.research.google.com/drive/1-y8HLiyS0oCj7Vpy4i7FsJ1A6kU7ROca?usp=sharing

```python
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
import matplotlib.pyplot as plt

model_name = "5CD-AI/ColVintern-1B-v1"

# The processor and model classes ship with the checkpoint, hence trust_remote_code=True.
processor = AutoProcessor.from_pretrained(
    model_name,
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval().cuda()

# Download the example images first (in a notebook, uncomment the two lines below).
#!wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex1.jpg
#!wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex2.jpg

images = [Image.open("ex1.jpg"), Image.open("ex2.jpg")]
batch_images = processor.process_images(images)

queries = [
    "Cảng Hải Phòng thông báo gì ?",  # "What does Hai Phong Port announce?"
    "Phí giao hàng bao nhiêu ?",      # "How much is the delivery fee?"
]

batch_queries = processor.process_queries(queries)

# Move the inputs to the GPU and match the model's bfloat16 dtype.
batch_images["pixel_values"] = batch_images["pixel_values"].cuda().bfloat16()
batch_images["input_ids"] = batch_images["input_ids"].cuda()
batch_images["attention_mask"] = batch_images["attention_mask"].cuda().bfloat16()
batch_queries["input_ids"] = batch_queries["input_ids"].cuda()
batch_queries["attention_mask"] = batch_queries["attention_mask"].cuda().bfloat16()

# Embed images and queries without tracking gradients.
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

# scores[i][j] is the relevance of image j to query i.
scores = processor.score_multi_vector(query_embeddings, image_embeddings)

# For each query, keep the best-scoring image.
max_scores, max_indices = torch.max(scores, dim=1)
# Print the result for each query and show the retrieved image.
for i, query in enumerate(queries):
    print(f"Query: '{query}'")
    print(f"Score: {max_scores[i].item()}\n")
    plt.figure(figsize=(5, 5))
    plt.imshow(images[max_indices[i].item()])
    plt.show()
```

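The per-key `.cuda()` calls in the Quickstart can be folded into one small function. This is a sketch only: `move_batch` is a hypothetical helper, not part of the released processor, and it assumes the processor output behaves like a dictionary of tensors. The keys cast to the model dtype mirror the Quickstart code. Whether the model itself runs well on CPU is not covered here.

```python
import torch


def move_batch(batch, device, dtype, cast_keys=("pixel_values", "attention_mask")):
    """Hypothetical helper: move every tensor to `device`, casting selected keys to `dtype`."""
    moved = {}
    for key, value in batch.items():
        if not torch.is_tensor(value):
            # Leave non-tensor entries untouched.
            moved[key] = value
            continue
        value = value.to(device)
        if key in cast_keys:
            # Mirrors the .bfloat16() casts in the Quickstart.
            value = value.to(dtype)
        moved[key] = value
    return moved


# Stand-in batch with the same keys as the Quickstart (shapes are arbitrary).
device = "cuda" if torch.cuda.is_available() else "cpu"
dummy_batch = {
    "input_ids": torch.ones(2, 4, dtype=torch.long),
    "attention_mask": torch.ones(2, 4, dtype=torch.long),
    "pixel_values": torch.zeros(2, 3, 8, 8),
}
moved = move_batch(dummy_batch, device, torch.bfloat16)
print({key: (value.dtype, value.device.type) for key, value in moved.items()})
```

With such a helper, the five assignment lines would reduce to `batch_images = move_batch(batch_images, "cuda", torch.bfloat16)` and the same call for `batch_queries`.
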
## Citation

<!-- ```
@misc{doan2024vintern1befficientmultimodallarge,
      title={Vintern-1B: An Efficient Multimodal Large Language Model for Vietnamese},
      author={Khang T. Doan and Bao G. Huynh and Dung T. Hoang and Thuc D. Pham and Nhat H. Pham and Quan T. M. Nguyen and Bang Q. Vo and Suong N. Hoang},
      year={2024},
      eprint={2408.12480},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2408.12480},
}
``` -->