Spaces:
Runtime error
Runtime error
import io | |
import requests | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
import gradio as gr | |
import uform | |
# Multilingual UForm vision-language model; find_topk() uses it to embed prompts.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Embedding matrix loaded from disk and copied into a torch tensor.
# NOTE(review): find_topk() indexes img_df with positions taken from this
# tensor, so row i is presumably the embedding of img_df row i - confirm.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# Disabled: per-image features, presumably for the commented-out rerank step.
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata table; find_topk() reads its 'url' column.
img_df = pd.read_csv('image_data.csv')
def url2img(url):
    """Download the image at ``url`` and return its pixels as a NumPy array.

    Args:
        url: HTTP(S) address of the image. Redirects are followed.

    Returns:
        numpy.ndarray: the decoded image converted to RGB, i.e. an array of
        shape (height, width, 3).

    Raises:
        requests.RequestException: on connection failure, timeout, or an
            HTTP error status.
        PIL.UnidentifiedImageError: if the response body is not an image
            Pillow can decode.
    """
    # A timeout keeps one unresponsive host from hanging the whole search.
    response = requests.get(url, allow_redirects=True, timeout=10)
    # Fail loudly on 4xx/5xx rather than trying to decode an error page.
    response.raise_for_status()
    # The bytes must be decoded first: np.array() applied directly to a
    # BytesIO just wraps the buffer object in a 0-d object array.
    with Image.open(io.BytesIO(response.content)) as img:
        # Normalise palette/greyscale/alpha images to plain RGB.
        return np.array(img.convert('RGB'))
def find_topk(text, top_k=10):
    """Return the images whose embeddings best match a text prompt.

    The prompt is embedded with ``model_multi`` and compared by cosine
    similarity against the module-level ``embeddings`` tensor. The best
    matching positions are looked up in ``img_df`` and their ``url`` entries
    are downloaded with ``url2img``.

    Args:
        text: The search prompt.
        top_k: Maximum number of images to return (default 10). Clamped to
            the number of available embeddings.

    Returns:
        list: Downloaded images in descending similarity order. Empty when
        the prompt is blank; shorter than ``top_k`` when fewer embeddings
        exist or some downloads fail.
    """
    print('text', text)
    # Nothing to search for; an empty list is a valid (empty) gallery value.
    if not text or not text.strip():
        return []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        # Only the embedding is used; the features would feed a reranker.
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
        # NOTE(review): assumes text_embedding broadcasts against embeddings
        # so that sims holds one score per stored image - confirm shapes.
        sims = F.cosine_similarity(text_embedding, embeddings)
        # topk() raises if k exceeds the number of candidates, so clamp it.
        k = max(0, min(top_k, sims.shape[-1]))
        _, inds = sims.topk(k)

    # Convert to plain Python ints instead of relying on pandas coercing a
    # torch tensor used as a positional indexer.
    top_k_urls = img_df.iloc[inds.tolist()]['url'].values
    print('top_k_urls', top_k_urls)

    images = []
    for url in top_k_urls:
        try:
            images.append(url2img(url))
        except OSError as err:
            # requests' and Pillow's errors both derive from OSError; skip
            # the bad link so one failure does not discard every result.
            print('skipping', url, err)
    return images
# def rerank(text_features, text_data): | |
# # create joint embeddings & get scores | |
# joint_embedding = model_multi.encode_multimodal( | |
# image_features=image_features, | |
# text_features=text_features, | |
# attention_mask=text_data['attention_mask'] | |
# ) | |
# score = model_multi.get_matching_scores(joint_embedding) | |
# # argmax to get top N | |
# return | |
# Earlier single-image interface, kept for reference:
#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

# Log the installed Gradio version at start-up.
print('version', gr.__version__)

# Search UI: a table of supported languages next to the prompt controls,
# with the result gallery underneath.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost - confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Two language columns side by side; the '#' column is a visual
        # separator and '.' pads the unpaired last row.
        gr.Markdown('| Code     | Lang     | #    | Code     | Lang                 |\n'
                    '| :------- | :------- | :--- | :------- | :------------------- |\n'
                    '| eng_Latn | English  | #    | fra_Latn | French               |\n'
                    '| deu_Latn | German   | #    | ita_Latn | Italian              |\n'
                    # Spanish is spa_Latn; it was listed under Italian's code.
                    '| spa_Latn | Spanish  | #    | jpn_Jpan | Japanese             |\n'
                    '| tur_Latn | Turkish  | #    | zho_Hans | Chinese (Simplified) |\n'
                    '| kor_Hang | Korean   | #    | pol_Latn | Polish               |\n'
                    '| rus_Cyrl | Russian  | #    | .        | .                    |\n')
        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")
    gallery = gr.Gallery()
    # Clicking the button runs the search and fills the gallery.
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)
# Start the Gradio server only when this file is run as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    demo.launch()