marinap's picture
updated image_data.csv with image URLs, using those in the app too
2a6db88
raw
history blame
2.74 kB
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
# Multilingual vision-language model, looked up by name through uform.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed embeddings read from disk and wrapped as a torch tensor so they
# can be compared against text embeddings in find_topk.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata table; find_topk reads its 'photo_image_url' column using
# row positions taken from the similarity ranking over `embeddings`.
img_df = pd.read_csv('image_data.csv')
def url2img(url, timeout=10):
    """Download the image at ``url`` and open it with PIL.

    Args:
        url: Address of the image to fetch; redirects are followed.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot block the caller forever.

    Returns:
        The PIL image opened from the downloaded bytes.

    Raises:
        requests.RequestException: If the request times out, cannot be
            completed, or the server answers with an HTTP error status.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Surface 4xx/5xx responses here instead of passing an error page to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def find_topk(text, top_k=10):
    """Return the images whose embeddings are most similar to ``text``.

    The prompt is encoded with ``model_multi``, compared by cosine similarity
    against the module-level ``embeddings`` tensor, and the best-scoring row
    positions are used to look up URLs in ``img_df['photo_image_url']``.

    Args:
        text: Prompt to search for.
        top_k: Maximum number of images to return. Clamped to the number of
            available similarity scores.

    Returns:
        A list of PIL images downloaded via ``url2img``, best match first.
    """
    print('text', text)
    text_data = model_multi.preprocess_text(text)

    # Pure inference: no autograd bookkeeping is needed for a lookup.
    with torch.no_grad():
        # Only the embedding is used; the per-token features are discarded.
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
        sims = F.cosine_similarity(text_embedding, embeddings)

    # topk raises when asked for more entries than exist, so clamp the request.
    k = max(0, min(top_k, sims.shape[-1]))
    _, inds = sims.topk(k)

    # Hand pandas plain Python ints rather than a torch tensor.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values
    print('top_k_urls', top_k_urls)
    return [url2img(url) for url in top_k_urls]
# def rerank(text_features, text_data):
# # create joint embeddings & get scores
# joint_embedding = model_multi.encode_multimodal(
# image_features=image_features,
# text_features=text_features,
# attention_mask=text_data['attention_mask']
# )
# score = model_multi.get_matching_scores(joint_embedding)
# # argmax to get top N
# return
#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')
# Print the installed Gradio version at startup.
print('version', gr.__version__)

# NOTE(review): indentation was reconstructed; confirm the intended nesting of
# the row/column containers and the gallery.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Table of supported languages, shown in two code/language column pairs.
        gr.Markdown('| Code | Lang | # | Code | Lang |\n'
                    '| :------- | :------- | :--- | :------- | :------------------- |\n'
                    '| eng_Latn | English | # | fra_Latn | French |\n'
                    '| deu_Latn | German | # | ita_Latn | Italian |\n'
                    # Spanish previously reused Italian's code (ita_Latn).
                    '| spa_Latn | Spanish | # | jpn_Jpan | Japanese |\n'
                    '| tur_Latn | Turkish | # | zho_Hans | Chinese (Simplified) |\n'
                    '| kor_Hang | Korean | # | pol_Latn | Polish |\n'
                    '| rus_Cyrl | Russian | # | . | . |\n')
        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")
    gallery = gr.Gallery()
    # Clicking the button sends the prompt to find_topk and shows the returned
    # images in the gallery.
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)
# Start the Gradio app only when this file is executed as a script.
if __name__ == '__main__':
    demo.launch()