---
language:
- en
tags:
- information retrieval
- embedding model
- visual information retrieval
metrics:
- recall
pipeline_tag: feature-extraction
---

# Memex: An OCR-free Visual-Based Document Embedding Model Based on MiniCPM-V-2.0

With MiniCPM-Visual-Embedding, you can build a knowledge base directly from raw PDFs, books, and other documents without any OCR technique or OCR pipeline. The model takes only images as document-side inputs and produces vectors representing document pages. minicpm-visual-embedding-v0 is trained on over 30k query/document-page pairs covering textual documents, visual documents, arXiv figures, industry documents, textbooks, ebooks, and more. It performs on par with a text embedding model on text-oriented documents and has an advantage on visually intensive documents.
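
A knowledge base built on this model can be sketched as a matrix of page vectors searched by dot product, the same scoring used in the Get started example. The `PageIndex` class and its page-id strings below are hypothetical illustrations, not part of this repo.

```python
import torch


class PageIndex:
    """Minimal in-memory index over page embeddings (a sketch, not a vector database)."""

    def __init__(self, dim: int):
        self.vectors = torch.empty(0, dim)  # one row per indexed page
        self.page_ids: list = []            # hypothetical identifiers, e.g. "book.pdf#p12"

    def add(self, page_ids: list, reps: torch.Tensor) -> None:
        # reps: [num_pages, dim] document-side embeddings
        if reps.shape[0] != len(page_ids):
            raise ValueError("need exactly one id per embedding row")
        self.vectors = torch.cat([self.vectors, reps.detach().float().cpu()], dim=0)
        self.page_ids.extend(page_ids)

    def search(self, q_rep: torch.Tensor, k: int = 3) -> list:
        # q_rep: [dim] query embedding; raw dot-product scores, as in the Get started example
        scores = self.vectors @ q_rep.detach().float().cpu()
        top = torch.topk(scores, k=min(k, len(self.page_ids)))
        return [(self.page_ids[i], s) for s, i in zip(top.values.tolist(), top.indices.tolist())]
```

With the tensors from the Get started example, `index = PageIndex(dim=p_reps.shape[1])`, then `index.add(["memex.png", "us2020.png", "hard_negative.png"], p_reps)` and `index.search(q_reps[0])` would rank the three pages by the same scores printed there.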

[GitHub Repo](https://github.com/bokesyo/minicpm-visual-embedding)

![Memex Architecture](images/memex.png)

# News

- 2024-06-27: We released our first visual embedding model checkpoint, minicpm-visual-embedding-v0, on [Hugging Face](https://huggingface.co/RhapsodyAI/minicpm-visual-embedding-v0).
- 2024-05-08: We [committed](https://github.com/bokesyo/minicpm-visual-embedding) our training code (full-parameter tuning with GradCache and DeepSpeed, supporting large batch sizes across multiple GPUs with ZeRO stage 1) and evaluation code.

# Get started

Install the following dependencies with pip:

```
Pillow==10.1.0
timm==0.9.10
torch==2.1.2
torchvision==0.16.2
transformers==4.36.0
sentencepiece==0.1.99
numpy==1.26.0
```
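
To confirm that an existing environment matches these pins, a small check using the standard-library `importlib.metadata` can be sketched as follows. `PINNED` and `check_pins` are hypothetical names; the versions are copied from the list above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Pins copied from the dependency list above (distribution names as published on PyPI).
PINNED = {
    "Pillow": "10.1.0",
    "timm": "0.9.10",
    "torch": "2.1.2",
    "torchvision": "0.16.2",
    "transformers": "4.36.0",
    "sentencepiece": "0.1.99",
    "numpy": "1.26.0",
}


def check_pins(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed)} for every package that does not match its pin."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name).split("+")[0]  # drop local build tags such as "+cu121"
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_pins(PINNED).items():
        print(f"{name}: expected {expected}, found {installed}")
```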

First, clone this Hugging Face repo with git (the model files are stored with Git LFS), or download it with `huggingface-cli`:

```bash
git lfs install
git clone https://huggingface.co/RhapsodyAI/minicpm-visual-embedding-v0
```

or

```bash
huggingface-cli download RhapsodyAI/minicpm-visual-embedding-v0
```

```python
from transformers import AutoModel
from transformers import AutoTokenizer
from PIL import Image
import torch

device = 'cuda:0'

# This function is borrowed from https://huggingface.co/intfloat/e5-mistral-7b-instruct
def last_token_pool(last_hidden_states, attention_mask):
    # If every sequence has a real token in the last position, the batch is left-padded
    # (or unpadded), so the final position holds each sequence's last token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        # Right padding: take the hidden state at each sequence's last non-padding position.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

# Load the model; replace `model_path` with the path of your local copy of this repo
model_path = '/local/path/to/model'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model.to(device)

# Load document page images as RGB PIL.Image objects
image_1 = Image.open('/local/path/to/images/memex.png').convert('RGB')
image_2 = Image.open('/local/path/to/images/us2020.png').convert('RGB')
image_3 = Image.open('/local/path/to/images/hard_negative.png').convert('RGB')

# User query, prefixed with the query-side instruction
query_instruction = 'Represent this query for retrieving relavant document: '
query = 'Who was elected as president of United States in 2020?'
query_full = query_instruction + query

# Embed image documents (document side: one image per page, empty text)
with torch.no_grad():
    p_outputs = model(text=['', '', ''], image=[image_1, image_2, image_3], tokenizer=tokenizer)
    p_reps = last_token_pool(p_outputs.last_hidden_state, p_outputs.attention_mask)

# Embed text queries (query side: text only, no image)
with torch.no_grad():
    q_outputs = model(text=[query_full], image=[None], tokenizer=tokenizer)  # [B, s, d]
    q_reps = last_token_pool(q_outputs.last_hidden_state, q_outputs.attention_mask)  # [B, d]

# Calculate similarities as dot products between query and page embeddings
scores = torch.matmul(q_reps, p_reps.T)
print(scores)

# tensor([[0.6506, 4.9630, 3.8614]], device='cuda:0')
# image_2 (us2020.png) receives the highest score
```
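
`last_token_pool` takes each sequence's embedding from the hidden state of its last non-padding token. The toy tensors below (hypothetical values, hidden size 1, each hidden value equal to its position) are a minimal sketch showing that both the right-padding and the left-padding branch select that token.

```python
import torch


def last_token_pool(last_hidden_states, attention_mask):
    # Same logic as in the example above (borrowed from intfloat/e5-mistral-7b-instruct).
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        # Every sequence ends with a real token, so the last position is the one to keep.
        return last_hidden_states[:, -1]
    # Right padding: index of the last 1 in each row of the mask.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


# Toy batch: 2 sequences x 3 positions x hidden size 1; the hidden value equals the position.
hidden = torch.arange(3, dtype=torch.float32).repeat(2, 1).unsqueeze(-1)  # [2, 3, 1]

right_padded = torch.tensor([[1, 1, 0],
                             [1, 1, 1]])
left_padded = torch.tensor([[0, 1, 1],
                            [1, 1, 1]])

print(last_token_pool(hidden, right_padded).squeeze(-1))  # tensor([1., 2.])
print(last_token_pool(hidden, left_padded).squeeze(-1))   # tensor([2., 2.])
```

In the right-padded batch the first sequence ends at position 1, so its embedding comes from there; in the left-padded batch every sequence ends at the final position.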

# Limitations

Currently, please make sure input images are rendered at a high resolution such as `300` dpi; a lower resolution such as `100` dpi may degrade model performance. We plan to augment the training data to fix this in the next version.
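
Because a PIL image does not record the physical size of the page it came from, one rough way to check the resolution requirement above is to divide the pixel width by an assumed page width. The sketch below assumes US Letter width (8.5 in) by default; `effective_dpi` and `check_resolution` are hypothetical helpers. For PDFs, re-rendering pages at 300 dpi (for example with pdf2image's `convert_from_path(path, dpi=300)`) is likely better than upscaling low-resolution images, which adds pixels but not detail.

```python
from PIL import Image

MIN_DPI = 300  # resolution recommended in the Limitations note above


def effective_dpi(image: Image.Image, page_width_in: float = 8.5) -> float:
    """Estimate horizontal dpi from pixel width and an ASSUMED physical page width in inches.

    8.5 in is US Letter; pass 8.27 for A4. PIL images do not record the physical
    page size, so this is only an estimate.
    """
    return image.width / page_width_in


def check_resolution(image: Image.Image, page_width_in: float = 8.5, min_dpi: int = MIN_DPI) -> bool:
    """Return True if the page image appears to be rendered at or above `min_dpi`."""
    return effective_dpi(image, page_width_in) >= min_dpi


# A Letter page rendered at 300 dpi is 2550 x 3300 pixels; at 100 dpi it is 850 x 1100.
hi_res = Image.new("RGB", (2550, 3300), "white")
lo_res = Image.new("RGB", (850, 1100), "white")
print(effective_dpi(hi_res), check_resolution(hi_res))  # 300.0 True
print(effective_dpi(lo_res), check_resolution(lo_res))  # 100.0 False
```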