"""Gradio demo: multilingual text-to-image search over precomputed embeddings.

A text prompt is encoded with the UForm multilingual model, compared by
cosine similarity against the image embeddings loaded from
``tensors/embeddings.npy``, and the best-matching images (URLs taken from
``image_data.csv``) are downloaded and shown in a gallery.
"""

import io
from datetime import datetime

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn.functional as F
import uform
from PIL import Image

model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

embeddings = np.load('tensors/embeddings.npy')
embeddings = torch.tensor(embeddings)

#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

img_df = pd.read_csv('image_data.csv')


def url2img(url, resize=False, fix_height=200, timeout=10):
    """Download the image at *url* and return it as a PIL image.

    Args:
        url: Address of the image to fetch.
        resize: When true, scale the image to ``fix_height`` pixels high,
            keeping the aspect ratio (see ``resize_wh``).
        fix_height: Target height in pixels, used only when ``resize`` is set.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot block the request forever.

    Raises:
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status.
        OSError: If the downloaded bytes cannot be opened as an image.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if resize:
        width, height = img.size
        img = img.resize(resize_wh(width, height, fix_height))
    return img


def resize_wh(w, h, fix_height):
    """Return ``(new_width, fix_height)`` preserving the ``w / h`` ratio.

    The width is truncated to an integer and clamped to at least 1 so that
    very narrow images do not produce a zero-width target size.
    """
    ratio = w / h
    w_n = max(1, int(fix_height * ratio))
    return w_n, fix_height


def find_topk(text, top_k=10):
    """Return the images whose embeddings are most similar to *text*.

    Args:
        text: Prompt entered by the user.
        top_k: Maximum number of images to return. It is capped at the number
            of available embeddings.

    Returns:
        A list of PIL images ordered from most to least similar. Images that
        fail to download or decode are skipped, so the list may be shorter
        than ``top_k``; it is empty for a blank prompt.
    """
    print('text', text)
    if not text or not text.strip():
        # Nothing to search for; skip the model call entirely.
        return []

    # Inference only: no gradients are needed for encoding or similarity.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
        print('Got features', datetime.now().strftime("%H:%M:%S"))

        sims = F.cosine_similarity(text_embedding, embeddings)
        # topk raises if k exceeds the number of candidates.
        k = min(top_k, sims.shape[0])
        vals, inds = sims.topk(k)

    # Convert the index tensor to plain ints before positional indexing.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values
    print('top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))

    images = []
    for url in top_k_urls:
        try:
            images.append(url2img(url))
        except (requests.RequestException, OSError) as err:
            # One dead link should not discard the rest of the results.
            print('failed to load image', url, err)
    print('got PIL images')
    print(datetime.now().strftime("%H:%M:%S"))
    return images


# def rerank(text_features, text_data):
#     # create joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)
#     # argmax to get top N
#     return

#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

print('version', gr.__version__)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Supported languages, laid out as two code/language column pairs.
        gr.Markdown('| Code     | Lang     | #    | Code     | Lang                 |\n'
                    '| :------- | :------- | :--- | :------- | :------------------- |\n'
                    '| eng_Latn | English  | #    | fra_Latn | French               |\n'
                    '| deu_Latn | German   | #    | ita_Latn | Italian              |\n'
                    '| spa_Latn | Spanish  | #    | jpn_Jpan | Japanese             |\n'
                    '| tur_Latn | Turkish  | #    | zho_Hans | Chinese (Simplified) |\n'
                    '| kor_Hang | Korean   | #    | pol_Latn | Polish               |\n'
                    '| rus_Cyrl | Russian  | #    | .        | .                    |\n')
        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")
    gallery = gr.Gallery()
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()