Update README.md
Browse files
README.md
CHANGED
@@ -56,8 +56,8 @@ import torch
|
|
56 |
|
57 |
device = 'cuda:0'
|
58 |
|
59 |
-
|
60 |
-
|
61 |
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
62 |
if left_padding:
|
63 |
return last_hidden_states[:, -1]
|
@@ -66,30 +66,36 @@ def last_token_pool(last_hidden_states: Tensor,
|
|
66 |
batch_size = last_hidden_states.shape[0]
|
67 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
68 |
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
image_2 = Image.open('/local/path/to/document2.png').convert('RGB')
|
75 |
|
|
|
76 |
query_instruction = 'Represent this query for retrieving relavant document: '
|
77 |
-
|
78 |
query = 'Who was elected as president of United States in 2020?'
|
79 |
-
|
80 |
query_full = query_instruction + query
|
81 |
|
82 |
-
# Embed text queries
|
83 |
-
q_outputs = model(text=[query_full], image=[None, None], tokenizer=tokenizer) # [B, s, d]
|
84 |
-
q_reps = last_token_pool(q_outputs.last_hidden_state, q_outputs.attention_mask) # [B, d]
|
85 |
-
|
86 |
# Embed image documents
|
87 |
-
|
88 |
-
|
|
|
89 |
|
90 |
-
#
|
91 |
-
|
|
|
|
|
92 |
|
|
|
|
|
93 |
print(scores)
|
94 |
|
|
|
|
|
95 |
```
|
|
|
56 |
|
57 |
device = 'cuda:0'
|
58 |
|
59 |
+
# This function is borrowed from https://huggingface.co/intfloat/e5-mistral-7b-instruct
|
60 |
+
def last_token_pool(last_hidden_states, attention_mask):
|
61 |
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
62 |
if left_padding:
|
63 |
return last_hidden_states[:, -1]
|
|
|
66 |
batch_size = last_hidden_states.shape[0]
|
67 |
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
68 |
|
69 |
+
# Load model, be sure to substitute `model_path` by your model path
|
70 |
+
model_path = '/local/path/to/model'
|
71 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
72 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
73 |
+
model.to(device)
|
74 |
|
75 |
+
# Load image to PIL.Image object
|
76 |
+
image_1 = Image.open('/local/path/to/images/memex.png').convert('RGB')
|
77 |
+
image_2 = Image.open('/local/path/to/images/us2020.png').convert('RGB')
|
78 |
+
image_3 = Image.open('/local/path/to/images/hard_negative.png').convert('RGB')
|
|
|
79 |
|
80 |
+
# User query
|
81 |
query_instruction = 'Represent this query for retrieving relavant document: '
|
|
|
82 |
query = 'Who was elected as president of United States in 2020?'
|
|
|
83 |
query_full = query_instruction + query
|
84 |
|
|
|
|
|
|
|
|
|
85 |
# Embed image documents
|
86 |
+
with torch.no_grad():
|
87 |
+
p_outputs = model(text=['', '', ''], image=[image_1, image_2, image_3], tokenizer=tokenizer)
|
88 |
+
p_reps = last_token_pool(p_outputs.last_hidden_state, p_outputs.attention_mask)
|
89 |
|
90 |
+
# Embed text queries
|
91 |
+
with torch.no_grad():
|
92 |
+
q_outputs = model(text=[query_full], image=[None], tokenizer=tokenizer) # [B, s, d]
|
93 |
+
q_reps = last_token_pool(q_outputs.last_hidden_state, q_outputs.attention_mask) # [B, d]
|
94 |
|
95 |
+
# Calculate similarities
|
96 |
+
scores = torch.matmul(q_reps, p_reps.T)
|
97 |
print(scores)
|
98 |
|
99 |
+
# tensor([[0.6506, 4.9630, 3.8614]], device='cuda:0')
|
100 |
+
|
101 |
```
|