File size: 1,671 Bytes
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
e173a84
 
 
 
 
a24abaa
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
 
 
697f3c2
 
e173a84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform


# Multilingual UForm vision-language model; used below to embed text prompts.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Embedding matrix loaded from disk and converted to a torch tensor in one
# step. Presumably row i corresponds to row i of image_data.csv -- TODO confirm.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# Per-image feature tensors are currently not loaded (kept for reference):
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata table; the search code reads its 'url' column.
img_df = pd.read_csv('image_data.csv')

def url2img(url, timeout=10):
    """Download the resource at *url* and return its raw bytes.

    Despite the name, the result is the undecoded response body, not an image
    object; wrap it in ``io.BytesIO`` and pass it to ``PIL.Image.open`` to
    decode it.

    Args:
        url: HTTP(S) address to fetch. Redirects are followed.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled host.

    Returns:
        The response body as ``bytes``.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            4xx/5xx status (so an HTML error page is never mistaken for
            image data).
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    return response.content

def find_topk(text, top_k=10):
    """Return the images whose stored embeddings best match a text prompt.

    The prompt is embedded with ``model_multi``, compared against every row of
    the module-level ``embeddings`` tensor by cosine similarity, and the URLs
    of the best-scoring rows of ``img_df`` are downloaded and decoded.

    Args:
        text: The user's search prompt.
        top_k: Maximum number of images to return (clamped to the number of
            stored embeddings).

    Returns:
        A list of ``PIL.Image.Image`` objects, best match first, which is the
        list-of-images form a ``gr.Gallery`` output expects. The list is empty
        for a blank prompt and may be shorter than ``top_k`` when some images
        cannot be fetched or decoded.
    """
    # Nothing to search for: hand the gallery an empty result.
    if not text or not text.strip():
        return []

    # Pure inference; no autograd bookkeeping is needed.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)

        sims = F.cosine_similarity(text_embedding, embeddings)

        # topk raises if k exceeds the number of candidates.
        k = max(0, min(top_k, sims.shape[0]))
        _, inds = sims.topk(k)

    # Convert tensor indices to plain ints before positional indexing.
    top_k_urls = img_df.iloc[inds.tolist()]['url']

    images = []
    for url in top_k_urls:
        try:
            images.append(Image.open(io.BytesIO(url2img(url))))
        except OSError:
            # Covers both requests' network errors and PIL's
            # "cannot identify image" errors; one dead link should not
            # discard the remaining results.
            continue
    return images



# def rerank(text_features, text_data):

#     # create joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)

#     # argmax to get top N

#     return


#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')
# Web UI: a two-line prompt box feeding find_topk, results shown in a gallery.
demo = gr.Interface(
    fn=find_topk,
    inputs=gr.Textbox(label='Enter your prompt', lines=2),
    outputs=gr.Gallery(),
    theme=gr.themes.Glass(),
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()