---
library_name: transformers
language:
  - vi
  - en
base_model:
  - 5CD-AI/Vintern-1B-v2
datasets:
  - vidore/colpali_train_set
  - 5CD-AI/Viet-Doc-VQA
  - 5CD-AI/Viet-OCR-VQA
  - 5CD-AI/Viet-Doc-VQA-II
tags:
  - colpali
---

ColVintern-1B-v1 🇻🇳 ❄️ - ColPali version for Vietnamese.

What's new in ColVintern-1B-v1!

  • We implemented and trained the ColPali pipeline on top of Vintern-1B-v2. The model supports retrieval-augmented generation (RAG) by encoding both questions and document images into embedding vectors, so that the images containing the information a question asks about can be retrieved.
  • This first, experimental version was trained on the English ColPali training set plus 5% of our Vietnamese image-based question-answer pairs.
  • The model achieves results close to ColPali v1 while offering strong Vietnamese support and using only 1 billion parameters, compared with 2B-3B for current ColPali models.
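The ColPali pipeline scores a question against a page image with ColBERT-style late interaction (MaxSim): each query token embedding is matched with its most similar image-patch embedding, and these maxima are summed into a single query-image score. Below is a minimal plain-Python sketch of that rule on toy 2-d vectors. The helper names `dot`, `maxsim` and `score_matrix` are hypothetical, and we assume (without having checked the remote code) that `processor.score_multi_vector` computes the same quantity on batched tensors.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim(query_vecs, doc_vecs):
    """Late-interaction score: for each query token vector, take its best
    match among the document (image patch) vectors, then sum those maxima."""
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)


def score_matrix(queries, docs):
    """Scores of shape (num_queries, num_docs), one MaxSim value per pair."""
    return [[maxsim(q, d) for d in docs] for q in queries]


# Toy example: one query with two token vectors, two "pages" of patch vectors.
query = [[1.0, 0.0], [0.0, 1.0]]
page_a = [[1.0, 0.0], [0.0, 0.5]]
page_b = [[0.2, 0.2]]
scores = score_matrix([query], [page_a, page_b])
print(scores)  # [[1.5, 0.4]]
```

Page A scores higher because each query token finds a close match somewhere on it.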

ColPali Benchmarks

We evaluated the model on the ViDoRe benchmark introduced in the ColPali paper. We excluded the TabF and Shift test sets because they are in French; we plan to extend the model to more languages in the near future.

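| Group | Model | ArxivQ | DocQ | InfoQ | TATQ | AI | Energy | Gov. | Health. | Avg. |
|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| Unstructured, text only | BM25 | - | 34.1 | - | 44.0 | 90.4 | 78.3 | 78.8 | 82.6 | - |
| | BGE-M3 | - | 28.4 | - | 36.1 | 88.4 | 76.8 | 77.7 | 84.6 | - |
| Unstructured + OCR | BM25 | 31.6 | 36.8 | 62.9 | 62.7 | 92.8 | 85.9 | 83.9 | 87.2 | 68.0 |
| | BGE-M3 | 31.4 | 25.7 | 60.1 | 50.5 | 90.2 | 83.6 | 84.9 | 91.1 | 64.7 |
| Unstructured + Captioning | BM25 | 40.1 | 38.4 | 70.0 | 61.5 | 88.0 | 84.7 | 82.7 | 89.2 | 69.3 |
| | BGE-M3 | 35.7 | 32.9 | 71.9 | 43.8 | 88.8 | 83.3 | 80.4 | 91.3 | 66.0 |
| Contrastive VLMs | Jina-CLIP | 25.4 | 11.9 | 35.5 | 3.3 | 15.2 | 19.7 | 21.4 | 20.8 | 19.2 |
| | Nomic-vision | 17.1 | 10.7 | 30.1 | 2.7 | 12.9 | 10.9 | 11.4 | 15.7 | 13.9 |
| | SigLIP (Vanilla) | 43.2 | 30.3 | 64.1 | 26.2 | 62.5 | 65.7 | 66.1 | 79.1 | 54.7 |
| ColPali | SigLIP (Vanilla) | 43.2 | 30.3 | 64.1 | 26.2 | 62.5 | 65.7 | 66.1 | 79.1 | 54.7 |
| | BiSigLIP (+fine-tuning) | 58.5 | 32.9 | 70.5 | 30.5 | 74.3 | 73.7 | 74.2 | 82.3 | 62.1 |
| | BiPali (+LLM) | 56.5 | 30.0 | 67.4 | 33.4 | 71.2 | 61.9 | 73.8 | 73.6 | 58.5 |
| | ColPali (+Late Inter.) | 79.1 | 54.4 | 81.8 | 65.8 | 96.2 | 91.0 | 92.7 | 94.4 | 81.9 |
| Ours | ColVintern-1B (+Late Inter.) | 71.6 | 48.3 | 84.6 | 59.6 | 92.9 | 88.7 | 89.4 | 95.2 | 78.8 |

Avg. is the mean over the eight datasets shown; "-" marks datasets with no reported score.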
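As a quick check on the Avg. column, the snippet below recomputes a row's average as the arithmetic mean of its eight per-dataset scores (TabF and Shift excluded), rounded to one decimal. `row_average` is a hypothetical helper; the values are copied from the ColVintern-1B row.

```python
def row_average(scores):
    """Mean of the per-dataset scores, rounded to one decimal like the table."""
    return round(sum(scores) / len(scores), 1)


# ArxivQ, DocQ, InfoQ, TATQ, AI, Energy, Gov., Health. (ColVintern-1B row)
colvintern = [71.6, 48.3, 84.6, 59.6, 92.9, 88.7, 89.4, 95.2]
print(row_average(colvintern))  # 78.8
```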

For upcoming versions, we are expanding the training dataset, adding hard-negative mining, and training with more GPU VRAM, among other improvements, to achieve better results.
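The card does not specify which hard-negative mining technique will be used. One common approach, shown here only as an illustrative assumption, scores every training query against candidate pages with the current model and keeps the highest-scoring pages that are not the query's positive as extra negatives for the contrastive loss. `mine_hard_negatives` is a hypothetical helper:

```python
def mine_hard_negatives(scores, positives, k=1):
    """For each query row in `scores` (num_queries x num_docs), return the
    indices of the k highest-scoring documents other than its positive."""
    hard = []
    for q, row in enumerate(scores):
        # Exclude the gold page, then rank the remaining pages by model score.
        candidates = [j for j in range(len(row)) if j != positives[q]]
        candidates.sort(key=lambda j: row[j], reverse=True)
        hard.append(candidates[:k])
    return hard


scores = [
    [0.9, 0.7, 0.1],   # query 0: gold page is 0, page 1 is a close distractor
    [0.2, 0.8, 0.85],  # query 1: gold page is 1, page 2 outscores it
]
print(mine_hard_negatives(scores, positives=[0, 1]))  # [[1], [2]]
```

In practice, mined pages that actually answer the query (false negatives) are often filtered out before training.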

Examples

Input Images: ex1.jpg and ex2.jpg from the model repository (see the wget lines in the Quickstart).

Input Queries:

queries = [
    "Cảng Hải Phòng thông báo gì ?",  # What did Hai Phong Port announce?
    "Phí giao hàng bao nhiêu ?",      # How much is the delivery fee?
]

Output Scores:

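| Query | Image 1 Score | Image 2 Score |
|---|---:|---:|
| Chuyện gì xảy ra với quốc lộ 5 TP Hải Phòng ? (What happened to National Highway 5 in Hai Phong City?) | 62.4333 | 59.9523 |
| Phí giao hàng bao nhiêu ? (How much is the delivery fee?) | 60.7748 | 62.8654 |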

Quickstart:

Colab: https://colab.research.google.com/drive/1-y8HLiyS0oCj7Vpy4i7FsJ1A6kU7ROca?usp=sharing

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
import matplotlib.pyplot as plt

model_name = "5CD-AI/ColVintern-1B-v1"

# trust_remote_code=True loads the custom processor/model code from the repository
processor = AutoProcessor.from_pretrained(
    model_name,
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval().cuda()  # inference mode, on GPU

# In a notebook, first download the two example images:
#!wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex1.jpg
#!wget https://huggingface.co/5CD-AI/ColVintern-1B-v1/resolve/main/ex2.jpg

# Preprocess the document images (the pages to retrieve from)
images = [Image.open("ex1.jpg"), Image.open("ex2.jpg")]
batch_images = processor.process_images(images)

queries = [
    "Cảng Hải Phòng thông báo gì ?",  # What did Hai Phong Port announce?
    "Phí giao hàng bao nhiêu ?",      # How much is the delivery fee?
]

batch_queries = processor.process_queries(queries)

# Move inputs to the GPU (pixel values and attention masks in bfloat16)
batch_images["pixel_values"] = batch_images["pixel_values"].cuda().bfloat16()
batch_images["input_ids"] = batch_images["input_ids"].cuda()
batch_images["attention_mask"] = batch_images["attention_mask"].cuda().bfloat16()
batch_queries["input_ids"] = batch_queries["input_ids"].cuda()
batch_queries["attention_mask"] = batch_queries["attention_mask"].cuda().bfloat16()

# Encode images and queries into multi-vector embeddings
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

# Late-interaction scores: one row per query, one column per image
scores = processor.score_multi_vector(query_embeddings, image_embeddings)

# Best-matching image for each query
max_scores, max_indices = torch.max(scores, dim=1)

# Print the results for each query
for i, query in enumerate(queries):
    print(f"Query: '{query}'")
    print(f"Score: {max_scores[i].item()}\n")
    plt.figure(figsize=(5, 5))
    plt.imshow(images[max_indices[i].item()])
    plt.show()
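The Quickstart keeps only the single best image per query (`torch.max(scores, dim=1)`). For RAG over a larger page collection you would typically pass the top k pages to the generator instead; on the `scores` tensor this is `torch.topk(scores, k, dim=1)`. A dependency-free sketch of the same ranking, with `top_k_pages` as a hypothetical helper and the numbers from the Output Scores table as example input:

```python
def top_k_pages(scores, k):
    """For each query row (num_queries x num_images), return the indices of
    the k highest-scoring images, best first."""
    return [
        sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        for row in scores
    ]


# Rows: the two example queries; columns: Image 1, Image 2.
scores = [
    [62.4333, 59.9523],
    [60.7748, 62.8654],
]
print(top_k_pages(scores, k=1))  # [[0], [1]]
print(top_k_pages(scores, k=2))  # [[0, 1], [1, 0]]
```

Assuming the scores are MaxSim sums over query tokens, they grow with query length, so rank pages within one query rather than comparing raw scores across queries.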

Citation