File size: 2,728 Bytes
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
e173a84
 
 
 
 
a24abaa
e173a84
 
 
f9ad361
e173a84
 
 
242b67f
 
e173a84
 
 
 
 
 
 
 
 
800542b
e173a84
242b67f
 
800542b
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
3076001
 
928efab
 
 
 
ed9d473
 
 
 
 
 
 
 
928efab
 
ed9d473
928efab
 
 
 
 
e173a84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform


# Multilingual UForm vision-language model; find_topk uses it to embed prompts.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Embedding matrix loaded from disk and wrapped as a tensor so it can be
# compared against the prompt embedding with torch ops.
# NOTE(review): presumably one row per row of image_data.csv, in the same
# order (find_topk uses the similarity indices to index img_df) - confirm.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# Disabled: image features for the (also disabled) reranking step.
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata; find_topk reads the 'url' column.
img_df = pd.read_csv('image_data.csv')

def url2img(url, timeout=10):
    """Download the image at *url* and open it with PIL.

    Args:
        url: Address of the image; redirects are followed.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block forever on an unresponsive host,
            which would stall the whole search request.

    Returns:
        The image object returned by ``Image.open`` for the downloaded bytes.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server responds with an HTTP error status.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Surface 4xx/5xx as an HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))

def find_topk(text, top_k=10):
    """Return the images whose embeddings best match a text prompt.

    The prompt is embedded with ``model_multi``, compared by cosine
    similarity against the module-level ``embeddings`` tensor, and the
    best-scoring rows of ``img_df`` are downloaded via ``url2img``.

    Args:
        text: The user's prompt.
        top_k: Maximum number of images to return. It is clamped to the
            number of similarity scores so a small index does not make
            ``topk`` raise.

    Returns:
        A list of images (as returned by ``url2img``), best match first.
        Empty when the prompt is empty or whitespace-only.
    """
    print('text', text)

    # Nothing to search for: skip the model call and the image downloads.
    if not text or not text.strip():
        return []

    text_data = model_multi.preprocess_text(text)

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        # The first returned value (text features) is not used here.
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)
        # NOTE(review): assumes text_embedding is (1, dim) and embeddings is
        # (n_images, dim), giving one score per image - confirm.
        sims = F.cosine_similarity(text_embedding, embeddings)

    k = min(top_k, sims.shape[0])
    _, inds = sims.topk(k)

    # Plain Python ints are the safest positional indexer for pandas.
    top_k_urls = img_df.iloc[inds.tolist()]['url'].values

    print('top_k_urls', top_k_urls)

    return [url2img(url) for url in top_k_urls]



# def rerank(text_features, text_data):

#     # create joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)

#     # argmax to get top N

#     return


#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

# Log the installed Gradio version at startup.
print('version', gr.__version__)

# Search UI: a table of supported languages next to a prompt box and a search
# button, with a gallery underneath that shows the images from find_topk.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Each table row holds two (code, language) pairs; the '#' column is
        # only a visual separator between them.
        gr.Markdown(
            '| Code     | Lang     | #    | Code     | Lang                 |\n'
            '| :------- | :------- | :--- | :------- | :------------------- |\n'
            '| eng_Latn | English  | #    | fra_Latn | French               |\n'
            '| deu_Latn | German   | #    | ita_Latn | Italian              |\n'
            # Spanish is spa_Latn; the table previously repeated Italian's code.
            '| spa_Latn | Spanish  | #    | jpn_Jpan | Japanese             |\n'
            '| tur_Latn | Turkish  | #    | zho_Hans | Chinese (Simplified) |\n'
            '| kor_Hang | Korean   | #    | pol_Latn | Polish               |\n'
            '| rus_Cyrl | Russian  | #    | .        | .                    |\n'
        )

        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")

    gallery = gr.Gallery()
    # Clicking the button sends the prompt text to find_topk and shows the
    # returned images in the gallery.
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()