File size: 3,051 Bytes
e173a84
 
 
 
 
 
 
 
 
848a638
 
e173a84
 
 
 
a24abaa
e173a84
 
 
 
 
a24abaa
e173a84
4a3a204
e173a84
bb5ec6d
 
95cdb44
bb5ec6d
 
e173a84
 
242b67f
 
4a3a204
e173a84
 
 
 
bb5ec6d
 
e173a84
 
 
2a6db88
e173a84
4a3a204
848a638
 
4a3a204
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
3076001
 
928efab
 
 
 
ed9d473
 
 
 
 
 
 
 
928efab
 
ed9d473
928efab
 
4a3a204
928efab
 
e173a84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
from datetime import datetime



# Multilingual CLIP-style vision-language model; its text encoder maps a
# prompt into the same embedding space as the precomputed image embeddings.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Precomputed image embeddings, one row per image, aligned by row index
# with img_df below.  Loaded as numpy then wrapped as a torch tensor so
# cosine_similarity/topk can be used at query time.
embeddings = np.load('tensors/embeddings.npy')
embeddings = torch.tensor(embeddings)

# Raw per-image features for the (currently disabled) multimodal re-ranking
# step — see the commented-out rerank() sketch further down.
#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata; 'photo_image_url' column holds the displayable URL for
# each embedded image.  NOTE(review): assumed to have the same row order
# as embeddings.npy — verify when regenerating either file.
img_df = pd.read_csv('image_data.csv')

def url2img(url, resize = False, fix_height = 150, timeout = 10):
    """Download an image from *url* and return it as a PIL Image.

    Parameters:
        url: HTTP(S) URL of the image.
        resize: when True, shrink the image in place so neither side
            exceeds *fix_height* pixels (aspect ratio preserved).
        fix_height: bounding-box size used by the thumbnail resize.
        timeout: seconds before the HTTP request is abandoned; without it
            a stalled server would hang the app indefinitely.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (previously the error page's bytes fell through to PIL and
            failed with an unrelated image-decoding error).
    """
    response = requests.get(url, allow_redirects = True, timeout = timeout)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if resize:
        # thumbnail() resizes in place and never upscales
        img.thumbnail([fix_height, fix_height], Image.LANCZOS)
    return img

def find_topk(text, top_k = 20):
    """Return the URLs of the *top_k* images most similar to *text*.

    The prompt is encoded with the multilingual text encoder and compared
    against the precomputed image embeddings by cosine similarity.

    Parameters:
        text: query prompt in any supported language.
        top_k: number of image URLs to return (default 20, matching the
            previous hard-coded behavior; clamped to the number of
            available images so topk cannot fail on small datasets).

    Returns:
        numpy array of 'photo_image_url' strings, best match first.
    """
    print('text', text)

    text_data = model_multi.preprocess_text(text)
    text_features, text_embedding = model_multi.encode_text(text_data, return_features=True)

    print('Got features', datetime.now().strftime("%H:%M:%S"))

    # Cosine similarity of the query embedding against every image embedding.
    sims = F.cosine_similarity(text_embedding, embeddings)

    # Never request more results than there are images.
    k = min(top_k, sims.shape[0])
    vals, inds = sims.topk(k)
    # Convert the index tensor to a plain list before handing it to pandas;
    # iloc is not guaranteed to accept a torch tensor directly.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values

    print('Got top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))

    return top_k_urls



# def rerank(text_features, text_data):

#     # craet joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)

#     # argmax to get top N

#     return


#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

print('version', gr.__version__)

# UI layout: a markdown table of supported language codes, a prompt box
# with a search button, and a gallery that shows the top-k matches.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown('# Enter a prompt in one of the supported languages.')
        with gr.Row():
            # FLORES-200-style language codes.  Fixed: Spanish was listed
            # with Italian's code (ita_Latn); its correct code is spa_Latn.
            gr.Markdown('| Code     | Lang     | #    | Code     | Lang                 |\n' 
                        '| :------- | :------- | :--- | :------- | :------------------- |\n'
                        '| eng_Latn | English  | #    | fra_Latn | French               |\n'
                        '| deu_Latn | German   | #    | ita_Latn | Italian              |\n'
                        '| spa_Latn | Spanish  | #    | jpn_Jpan | Japanese             |\n'
                        '| tur_Latn | Turkish  | #    | zho_Hans | Chinese (Simplified) |\n'
                        '| kor_Hang | Korean   | #    | pol_Latn | Polish               |\n'
                        '| rus_Cyrl | Russian  | #    | .        | .                    |\n')

            with gr.Column():
                prompt_box = gr.Textbox(label = 'Enter your prompt', lines = 3)
                btn_search = gr.Button("Find images")

        gallery = gr.Gallery().style(columns = [5],  height="auto", object_fit = "scale-down")
        # Clicking the button runs the text->image search and fills the gallery.
        btn_search.click(find_topk, inputs = prompt_box, outputs = gallery)

if __name__ == "__main__":
    demo.launch()