---
language:
- en
- tr
license: mit
library_name: Tevatron
tags:
- vidore
datasets:
- Tevatron/docmatix-ir
- HuggingFaceM4/Docmatix
- Tevatron/msmarco-passage-aug
---

# DSE-Phi3-Docmatix-V1

DSE-Phi3-Docmatix-V1 is a bi-encoder model designed to encode document screenshots into dense vectors for document retrieval. The Document Screenshot Embedding ([DSE](https://arxiv.org/abs/2406.11251)) approach captures documents in their original visual format, preserving all information such as text, images, and layout, thus avoiding tedious parsing and potential information loss.

The model, `Tevatron/dse-phi3-docmatix-v1`, is trained using 1/10 of the `Tevatron/docmatix-ir` dataset, a variant of `HuggingFaceM4/Docmatix` specifically adapted for training PDF retrievers with Vision Language Models in open-domain question answering scenarios. For more information on dataset filtering and hard negative mining, refer to the [docmatix-ir](https://huggingface.co/datasets/Tevatron/docmatix-ir/blob/main/README.md) dataset page.

DSE has strong zero-shot effectiveness for document retrieval with both visual and text input.
For example, DSE-Phi3-Docmatix-V1 achieves 74.1 nDCG@5 on the [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) leaderboard in a **zero-shot setting** (without fine-tuning on ViDoRe training data).
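
For readers unfamiliar with the metric, the following is a minimal sketch of how nDCG@5 can be computed for a single query. It uses the common linear-gain formulation and assumes every relevant document appears somewhere in the ranked list. The function names `dcg_at_k` and `ndcg_at_k` are illustrative, and the leaderboard's own evaluation code may differ in details. A leaderboard figure such as 74.1 would correspond to a mean score of roughly 0.741 over all queries, assuming scores are reported on a 0 to 100 scale.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """Discounted cumulative gain over the top-k relevance labels."""
    # rank is 0-based, so the discount for the first result is log2(2) = 1.
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))


def ndcg_at_k(ranked_relevances: Sequence[float], k: int = 5) -> float:
    """DCG of the system ranking divided by DCG of the ideal ranking."""
    # Assumption: the ideal ranking is built from the labels in this list only.
    ideal_dcg = dcg_at_k(sorted(ranked_relevances, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0
    return dcg_at_k(ranked_relevances, k) / ideal_dcg


# Toy example: the single relevant page is retrieved at rank 2.
score = ndcg_at_k([0, 1, 0, 0, 0], k=5)
print(f"nDCG@5 = {score:.4f}")
```

Placing the relevant page at rank 1 instead would give a score of 1.0; the logarithmic discount is what penalizes lower ranks.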


## How to Use the Model

### Load the Model and Processor

```python
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

processor = AutoProcessor.from_pretrained('Tevatron/dse-phi3-docmatix-v1', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('Tevatron/dse-phi3-docmatix-v1', trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, use_cache=False).to('cuda:0')

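def get_embedding(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pool the hidden state of the last non-padded token and L2-normalize it."""
    # Index of the last attended token in each sequence (assumes right padding).
    sequence_lengths = attention_mask.sum(dim=1) - 1
    bs = last_hidden_state.shape[0]
    reps = last_hidden_state[torch.arange(bs, device=last_hidden_state.device), sequence_lengths]
    # Unit-length vectors make the dot product equal to cosine similarity.
    reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps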
```
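
To make the pooling step concrete, here is a small CPU-only sketch that applies the same last-token logic to toy tensors, so no model download or GPU is needed. The helper name `last_token_pool` and the toy shapes are assumptions made for illustration. Note that computing the index as `attention_mask.sum(dim=1) - 1` only points at the last real token when padding is on the right.

```python
import torch


def last_token_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Same logic as get_embedding: pick the last attended token, then L2-normalize."""
    last_idx = mask.sum(dim=1) - 1  # valid only for right-padded batches
    batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
    return torch.nn.functional.normalize(hidden[batch_idx, last_idx], p=2, dim=-1)


# Toy batch: 2 sequences, 4 positions, hidden size 3.
hidden = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
# The second sequence has two real tokens followed by two padding positions.
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

reps = last_token_pool(hidden, mask)
picked = (mask.sum(dim=1) - 1).tolist()
print(picked)      # positions used for pooling, one per sequence
print(reps.shape)  # one vector per sequence
```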

### Encode Text Query

```python
queries = ["query: Where can we see Llama?</s>", "query: What is LLaMA model?</s>"]
query_inputs = processor(queries, return_tensors="pt", padding="longest", max_length=128, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**query_inputs, return_dict=True, output_hidden_states=True)
query_embeddings = get_embedding(output.hidden_states[-1], query_inputs["attention_mask"])
```
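
The query strings above follow a fixed pattern: a `query: ` prefix, the question, and a trailing `</s>` token. A small helper can apply that pattern consistently. `format_queries` below is a hypothetical convenience function, not part of Tevatron or the model's processor, and the pattern is inferred only from the example above.

```python
from typing import List

QUERY_PREFIX = "query: "  # prefix used in the example queries
EOS_TOKEN = "</s>"        # end-of-sequence token appended in the example queries


def format_queries(questions: List[str]) -> List[str]:
    """Wrap raw questions in the 'query: ...</s>' pattern shown above."""
    return [f"{QUERY_PREFIX}{q.strip()}{EOS_TOKEN}" for q in questions]


formatted = format_queries(["Where can we see Llama?", "  What is LLaMA model? "])
print(formatted)
```

The returned list can be passed to `processor(...)` in place of the hand-written `queries` list.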

### Encode Document Screenshot

```python
from PIL import Image
import requests
from io import BytesIO

# URLs of the images
url1 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v1/resolve/main/animal-llama.png"
url2 = "https://huggingface.co/Tevatron/dse-phi3-docmatix-v1/resolve/main/meta-llama.png"

# Download and open images
response1 = requests.get(url1)
response2 = requests.get(url2)

passage_image1 = Image.open(BytesIO(response1.content))
passage_image2 = Image.open(BytesIO(response2.content))

passage_images = [passage_image1, passage_image2]
passage_prompts = ["<|image_1|>\nWhat is shown in this image?</s>", "<|image_2|>\nWhat is shown in this image?</s>"]

# Process inputs and get embeddings
passage_inputs = processor(passage_prompts, images=passage_images, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
passage_inputs['input_ids'] = passage_inputs['input_ids'].squeeze(0)
passage_inputs['attention_mask'] = passage_inputs['attention_mask'].squeeze(0)
passage_inputs['image_sizes'] = passage_inputs['image_sizes'].squeeze(0)
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

```

### Compute Similarity

```python
from torch.nn.functional import cosine_similarity
num_queries = query_embeddings.size(0)
num_passages = doc_embeddings.size(0)

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```
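
Because `get_embedding` returns L2-normalized vectors, cosine similarity reduces to a dot product, so all query-document scores can also be computed with one matrix multiplication. The sketch below illustrates this with random stand-in tensors on CPU; the shapes and variable names are assumptions for the example, and with real embeddings `query_embeddings @ doc_embeddings.T` would play the same role.

```python
import torch

torch.manual_seed(0)

# Stand-ins for L2-normalized query and document embeddings.
q = torch.nn.functional.normalize(torch.randn(2, 8), dim=-1)  # 2 queries
d = torch.nn.functional.normalize(torch.randn(3, 8), dim=-1)  # 3 documents

# One matmul yields the full (num_queries, num_docs) similarity matrix.
scores = q @ d.T

# Reference: pairwise cosine similarity via broadcasting.
ref = torch.nn.functional.cosine_similarity(q.unsqueeze(1), d.unsqueeze(0), dim=-1)

# Document indices sorted from most to least similar, per query.
ranking = scores.argsort(dim=1, descending=True)
print(scores.shape, ranking.shape)
```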

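### Encode Document Text

This DSE checkpoint was warmed up with `Tevatron/msmarco-passage-aug`, so the model can also effectively encode documents given as text input. The example below reuses `query_embeddings` and `cosine_similarity` from the previous sections.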
```python
passage_prompts = [
  "The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.</s>",
  "Llama (acronym for Large Language Model Meta AI, and formerly stylized as LLaMA) is a family of autoregressive large language models (LLMs) released by Meta AI starting in February 2023.[2][3] The latest version is Llama 3.1, released in July 2024.[4]</s>"
]

passage_inputs = processor(passage_prompts, images=None, return_tensors="pt", padding="longest", max_length=4096, truncation=True).to('cuda:0')
with torch.no_grad():
    output = model(**passage_inputs, return_dict=True, output_hidden_states=True)
doc_embeddings = get_embedding(output.hidden_states[-1], passage_inputs["attention_mask"])

for i in range(num_queries):
    query_embedding = query_embeddings[i].unsqueeze(0)
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    print(f"Similarities for Query {i+1}: {similarities.cpu().float().numpy()}")
```

### Citation

If you find this checkpoint helpful, please consider citing Phi3, Docmatix, and our DSE work.