marinap's picture
get gradio version
3076001
raw
history blame
1.76 kB
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
# Multilingual UForm vision-language model used to embed text queries.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed image embeddings. find_topk indexes img_df positionally with
# row indices of this tensor, so row i is assumed to describe row i of
# image_data.csv -- TODO confirm against the script that produced the file.
_EMBEDDINGS_PATH = 'tensors/embeddings.npy'
embeddings = torch.tensor(np.load(_EMBEDDINGS_PATH))

# Image metadata table; find_topk reads its 'url' column.
_IMAGE_TABLE_PATH = 'image_data.csv'
img_df = pd.read_csv(_IMAGE_TABLE_PATH)
def url2img(url, timeout=10):
    """Download the resource at *url* and return its raw bytes.

    The body is returned as-is; it is not decoded into a PIL image.

    Args:
        url: Address of the image to fetch. Redirects are followed.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely, which would hang
            the Gradio worker serving the request.

    Returns:
        The response body as ``bytes``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            This keeps an error page from being passed on as image data.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    return response.content
def find_topk(text, top_k=10):
    """Return the images whose embeddings best match a text prompt.

    The prompt is embedded with ``model_multi`` and scored against the
    precomputed ``embeddings`` with cosine similarity. The row indices of the
    highest scores are looked up in ``img_df`` to get URLs, and each URL is
    downloaded with ``url2img``.

    Args:
        text: Query prompt (the Gradio textbox value).
        top_k: Maximum number of results. It is clamped to the number of rows
            in ``embeddings`` so ``topk`` does not fail on a small index.

    Returns:
        A list with one ``url2img`` result per match, ordered from most to
        least similar.
    """
    print('text', text)
    k = min(top_k, embeddings.shape[0])
    text_data = model_multi.preprocess_text(text)
    # Inference only: disable autograd so no graph is built on each request.
    with torch.no_grad():
        # The first returned value (text features) is only needed by the
        # unfinished rerank step, so it is discarded here.
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
        # Presumably text_embedding is (1, D) and embeddings is (N, D), giving
        # one similarity per indexed image -- TODO confirm shapes.
        sims = F.cosine_similarity(text_embedding, embeddings)
    _, inds = sims.topk(k)
    # Convert to plain ints so the positional pandas lookup does not depend
    # on torch-tensor handling (or tensor device) inside DataFrame.iloc.
    top_k_urls = img_df.iloc[inds.tolist()]['url'].values
    print('top_k_urls', top_k_urls)
    return [url2img(url) for url in top_k_urls]
# def rerank(text_features, text_data):
# # create joint embeddings & get scores
# joint_embedding = model_multi.encode_multimodal(
# image_features=image_features,
# text_features=text_features,
# attention_mask=text_data['attention_mask']
# )
# score = model_multi.get_matching_scores(joint_embedding)
# # argmax to get top N
# return
print('version', gr.__version__)

# UI: a two-line prompt box as input, a gallery of matching images as output.
prompt_box = gr.Textbox(label='Enter your prompt', lines=2)
result_gallery = gr.Gallery()
demo = gr.Interface(fn=find_topk, inputs=prompt_box, outputs=result_gallery)

if __name__ == "__main__":
    demo.launch()