feat: add cosine similarity measure
Browse files- app.py +7 -5
- lib/utils/model.py +0 -1
app.py
CHANGED
@@ -3,6 +3,7 @@ from lib.utils.model import get_model, get_similarities
|
|
3 |
from PIL import Image
|
4 |
|
5 |
st.title('IRRA Text-To-Image-Retrival')
|
|
|
6 |
|
7 |
st.header('Inputs')
|
8 |
caption = st.text_input('Description Input')
|
@@ -12,7 +13,7 @@ if images is not None:
|
|
12 |
st.image(images) # type: ignore
|
13 |
|
14 |
st.header('Options')
|
15 |
-
st.subheader('Ranks')
|
16 |
|
17 |
ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
|
18 |
|
@@ -26,15 +27,16 @@ if button:
|
|
26 |
st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
|
27 |
|
28 |
with st.spinner('Computing and ranking similarities'):
|
29 |
-
similarities = get_similarities(caption, images, model)
|
30 |
|
31 |
-
indices = similarities.argsort(descending=True).
|
32 |
|
33 |
for i, idx in enumerate(indices):
|
34 |
-
c1, c2 = st.columns(
|
35 |
with c1:
|
36 |
st.text(f'Rank {i + 1}')
|
37 |
with c2:
|
38 |
st.image(images[idx])
|
39 |
-
|
|
|
40 |
|
|
|
3 |
from PIL import Image
|
4 |
|
5 |
st.title('IRRA Text-To-Image-Retrival')
|
6 |
+
st.markdown('A text-to-image retrieval model implemented from [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
|
7 |
|
8 |
st.header('Inputs')
|
9 |
caption = st.text_input('Description Input')
|
|
|
13 |
st.image(images) # type: ignore
|
14 |
|
15 |
st.header('Options')
|
16 |
+
st.subheader('Ranks', help='How many predictions the model is allowed to make')
|
17 |
|
18 |
ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
|
19 |
|
|
|
27 |
st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
|
28 |
|
29 |
with st.spinner('Computing and ranking similarities'):
|
30 |
+
similarities = get_similarities(caption, images, model).squeeze(0)
|
31 |
|
32 |
+
indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
|
33 |
|
34 |
for i, idx in enumerate(indices):
|
35 |
+
c1, c2, c3 = st.columns(3)
|
36 |
with c1:
|
37 |
st.text(f'Rank {i + 1}')
|
38 |
with c2:
|
39 |
st.image(images[idx])
|
40 |
+
with c3:
|
41 |
+
st.text(f'Cosine sim {similarities[idx].cpu():.2f}')
|
42 |
|
lib/utils/model.py
CHANGED
@@ -24,7 +24,6 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
|
24 |
txt = tokenize(text, tokenizer)
|
25 |
imgs = prepare_images(images)
|
26 |
|
27 |
-
print(imgs.shape)
|
28 |
image_feats = model.encode_image(imgs)
|
29 |
text_feats = model.encode_text(txt.unsqueeze(0))
|
30 |
|
|
|
24 |
txt = tokenize(text, tokenizer)
|
25 |
imgs = prepare_images(images)
|
26 |
|
|
|
27 |
image_feats = model.encode_image(imgs)
|
28 |
text_feats = model.encode_text(txt.unsqueeze(0))
|
29 |
|