Spaces:
Runtime error
Runtime error
File size: 2,744 Bytes
e173a84 a24abaa e173a84 a24abaa e173a84 928efab e173a84 242b67f e173a84 800542b e173a84 242b67f 800542b e173a84 a24abaa 3076001 928efab ed9d473 928efab ed9d473 928efab e173a84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
# Multilingual UForm vision-language model; used below to embed text prompts.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed image embeddings read from disk and copied into a torch tensor
# so they can be compared with the query embedding via torch ops.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))
# NOTE: an alternative features file exists at
# 'multilingual-image-search/tensors/features.npy' (loading currently disabled).

# Image metadata table; find_topk reads its 'url' column by row position,
# so row i is assumed to correspond to embeddings[i] — TODO confirm.
img_df = pd.read_csv('image_data.csv')
def url2img(url, timeout=30):
    """Download an image over HTTP(S) and decode it with Pillow.

    Args:
        url: Address of the image to fetch (redirects are followed).
        timeout: Seconds to wait on the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled host.

    Returns:
        A ``PIL.Image.Image`` opened from the response body.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
        PIL.UnidentifiedImageError: if the body is not a recognisable image.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Fail on HTTP errors here instead of handing an HTML error page to PIL,
    # which would surface as a confusing "cannot identify image" error.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def find_topk(text, top_k=10):
    """Return the indexed images that best match a text prompt.

    The prompt is encoded with the multilingual UForm model and compared by
    cosine similarity against every row of the precomputed ``embeddings``.
    The best-scoring rows are mapped to URLs through ``img_df['url']`` and
    downloaded with ``url2img``.

    Args:
        text: Query prompt. An empty or whitespace-only prompt yields an
            empty result.
        top_k: Maximum number of images to return (default 10). It is
            clamped to the number of indexed images.

    Returns:
        A list of PIL images ordered from most to least similar. Images that
        fail to download or decode are skipped, so the list may be shorter
        than ``top_k``.
    """
    print('text', text)
    # Nothing to search for: show an empty gallery instead of embedding "".
    if not text or not text.strip():
        return []

    # topk() raises if k exceeds the number of rows, so clamp it.
    k = min(top_k, len(embeddings))

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
        # Presumably text_embedding is (1, D) and embeddings is (N, D), so the
        # default dim=1 broadcasts to one score per image — TODO confirm.
        sims = F.cosine_similarity(text_embedding, embeddings)
    _, inds = sims.topk(k)

    # Hand pandas plain ints; positional indexing with a torch tensor is not
    # a documented iloc input.
    top_k_urls = img_df.iloc[inds.tolist()]['url'].values
    print('top_k_urls', top_k_urls)

    images = []
    for url in top_k_urls:
        try:
            images.append(url2img(url))
        except OSError as err:
            # requests.RequestException and PIL.UnidentifiedImageError both
            # derive from OSError; one dead link should not blank the gallery.
            print('skipping', url, err)
    return images
# def rerank(text_features, text_data):
# # create joint embeddings & get scores
# joint_embedding = model_multi.encode_multimodal(
# image_features=image_features,
# text_features=text_features,
# attention_mask=text_data['attention_mask']
# )
# score = model_multi.get_matching_scores(joint_embedding)
# # argmax to get top N
# return
#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')
print('version', gr.__version__)

# Build the search UI: a language table plus prompt box and button on one row,
# with the result gallery below. Clicking the button runs find_topk on the
# prompt and shows the returned images in the gallery.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Supported languages as a two-half table; the "#" column only
        # separates the two halves visually.
        gr.Markdown('| Code | Lang | # | Code | Lang |\n'
                    '| :------- | :------- | :--- | :------- | :------------------- |\n'
                    '| eng_Latn | English | # | fra_Latn | French |\n'
                    '| deu_Latn | German | # | ita_Latn | Italian |\n'
                    '| spa_Latn | Spanish | # | jpn_Jpan | Japanese |\n'
                    '| tur_Latn | Turkish | # | zho_Hans | Chinese (Simplified) |\n'
                    '| kor_Hang | Korean | # | pol_Latn | Polish |\n'
                    '| rus_Cyrl | Russian | # | . | . |\n')
        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")
    gallery = gr.Gallery()
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()