"""Streamlit page for IRRA text-to-image retrieval.

The page collects a text description and a set of uploaded images, scores the
images against the description with the IRRA model, and displays the
best-matching images in rank order.
"""

import streamlit as st

from lib.utils.model import get_model, get_similarities

from PIL import Image


st.title('IRRA Text-To-Image-Retrival')

# --- Inputs -----------------------------------------------------------------
st.header('Inputs')
caption = st.text_input('Description Input')

images = st.file_uploader('Upload images', accept_multiple_files=True)
# With accept_multiple_files=True the uploader yields a list (empty when
# nothing has been uploaded), so test for emptiness rather than for None.
if images:
    st.image(images)

# --- Options ----------------------------------------------------------------
st.header('Options')
st.subheader('Ranks')

# Number of top-ranked images to display.
ranks = st.slider(
    'slider_ranks',
    min_value=1,
    max_value=10,
    label_visibility='collapsed',
    value=5,
)

# Matching needs at least one image and a non-blank description.
button = st.button(
    'Match most similar',
    disabled=not images or not caption.strip(),
)

# --- Results ----------------------------------------------------------------
if button:
    st.header('Results')

    # NOTE(review): get_model() is called on every click; unless it caches
    # internally, consider wrapping the load in st.cache_resource.
    with st.spinner('Loading model'):
        model = get_model()

    st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')

    with st.spinner('Computing and ranking similarities'):
        similarities = get_similarities(caption, images, model)

    # Image indices ordered from most to least similar, truncated to the
    # requested number of ranks. flatten() (instead of squeeze(0)) guarantees
    # a 1-D result, so tolist() always returns a list even when only a single
    # score comes back.
    order = similarities.argsort(descending=True).flatten()
    indices = order[:ranks].cpu().tolist()

    # One row per result: rank label on the left, matching image on the right.
    for i, idx in enumerate(indices):
        c1, c2 = st.columns(2)
        with c1:
            st.text(f'Rank {i + 1}')
        with c2:
            st.image(images[idx])