File size: 3,184 Bytes
e173a84
 
 
 
 
 
 
 
 
848a638
 
e173a84
 
 
 
a24abaa
e173a84
 
 
 
 
a24abaa
e173a84
bb5ec6d
e173a84
bb5ec6d
 
95cdb44
bb5ec6d
 
e173a84
 
242b67f
 
c89095f
e173a84
 
 
 
bb5ec6d
 
e173a84
 
 
2a6db88
e173a84
242b67f
848a638
 
194e5e6
848a638
 
 
242b67f
848a638
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
3076001
 
928efab
 
 
 
ed9d473
 
 
 
 
 
 
 
928efab
 
ed9d473
928efab
 
c89095f
928efab
 
e173a84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
from datetime import datetime



# Multilingual UForm model; find_topk uses it to encode the query text.
# NOTE(review): presumably its text embeddings share a space with the
# vectors in tensors/embeddings.npy — confirm against how they were built.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed image embeddings, converted to a torch tensor so they can be
# passed straight to F.cosine_similarity. find_topk maps a row index here to
# the same row position in img_df, so the two files are assumed to be
# aligned row-for-row — TODO confirm.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

# Image metadata; find_topk reads its 'photo_image_url' column.
img_df = pd.read_csv('image_data.csv')

def url2img(url, resize = False, fix_height = 200, timeout = 10):
    """Download an image over HTTP and return it as a PIL image.

    Args:
        url: Address of the image to fetch (redirects are followed).
        resize: When True, shrink the image in place with ``Image.thumbnail``
            so that neither side exceeds ``fix_height`` pixels.
        fix_height: Maximum side length, in pixels, used when ``resize`` is True.
        timeout: Seconds to wait for the server. Without a timeout, a stalled
            host would block the request indefinitely.

    Returns:
        The decoded PIL image.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
        PIL.UnidentifiedImageError: If the payload cannot be decoded as an image.
    """
    response = requests.get(url, allow_redirects = True, timeout = timeout)
    # Fail with a clear HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if resize:
        img.thumbnail((fix_height, fix_height), Image.LANCZOS)
    return img

def find_topk(text, top_k = 10):
    """Return the images whose embeddings are most similar to a text prompt.

    The prompt is encoded with ``model_multi``. It is scored against every row
    of ``embeddings`` by cosine similarity. The best rows are mapped to URLs
    through the same row positions in ``img_df['photo_image_url']``, and
    each URL is downloaded with ``url2img``.

    Args:
        text: The user's prompt. Empty or whitespace-only prompts (or None)
            return an empty list without touching the model.
        top_k: Number of images to return. It is capped at the number of
            stored embeddings, because ``Tensor.topk`` raises when k exceeds
            the available candidates.

    Returns:
        A list of PIL images, ordered from most to least similar.
    """
    print('text', text)

    if not (text and text.strip()):
        return []

    # Inference only: skip autograd bookkeeping for the text encoder.
    with torch.no_grad():
        text_data = model_multi.preprocess_text(text)
        _, text_embedding = model_multi.encode_text(text_data, return_features=True)

    print('Got features', datetime.now().strftime("%H:%M:%S"))

    # NOTE(review): assumes text_embedding is (1, D) and embeddings is (N, D),
    # so the default dim=1 broadcast yields one score per stored image.
    sims = F.cosine_similarity(text_embedding, embeddings)

    k = min(top_k, sims.shape[-1])
    _, inds = sims.topk(k)
    # Convert to plain ints so pandas positional indexing gets a list.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values

    print('top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))

    images = [url2img(url, resize = False) for url in top_k_urls]

    print('got PIL images')
    print(datetime.now().strftime("%H:%M:%S"))

    return images



# def rerank(text_features, text_data):

#     # create joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)

#     # argmax to get top N

#     return


#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

print('version', gr.__version__)

# Gradio UI: a table of supported language codes, a prompt box with a search
# button, and a gallery that find_topk fills with the matching images.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Two language entries per table row; the '#' column is only a
        # visual separator between them.
        gr.Markdown('| Code     | Lang     | #    | Code     | Lang                 |\n'
                    '| :------- | :------- | :--- | :------- | :------------------- |\n'
                    '| eng_Latn | English  | #    | fra_Latn | French               |\n'
                    '| deu_Latn | German   | #    | ita_Latn | Italian              |\n'
                    '| spa_Latn | Spanish  | #    | jpn_Jpan | Japanese             |\n'
                    '| tur_Latn | Turkish  | #    | zho_Hans | Chinese (Simplified) |\n'
                    '| kor_Hang | Korean   | #    | pol_Latn | Polish               |\n'
                    '| rus_Cyrl | Russian  | #    | .        | .                    |\n')

        with gr.Column():
            prompt_box = gr.Textbox(label = 'Enter your prompt', lines = 3)
            btn_search = gr.Button("Find images")

    # NOTE(review): Gallery.style() is deprecated in late Gradio 3.x and
    # removed in 4.x. If upgrading, pass these options to gr.Gallery(...)
    # directly instead.
    gallery = gr.Gallery().style(columns=5, rows=2, object_fit="contain", height="auto")
    btn_search.click(find_topk, inputs = prompt_box, outputs = gallery)

if __name__ == "__main__":
    demo.launch()