File size: 3,253 Bytes
e173a84
 
 
 
 
 
 
 
 
848a638
 
e173a84
 
 
 
a24abaa
e173a84
 
 
 
 
a24abaa
e173a84
bb5ec6d
e173a84
bb5ec6d
 
 
 
 
 
 
 
 
 
e173a84
 
 
242b67f
 
e173a84
 
 
 
 
bb5ec6d
 
e173a84
 
 
2a6db88
e173a84
242b67f
848a638
 
 
 
 
 
242b67f
848a638
e173a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a24abaa
3076001
 
928efab
 
 
 
ed9d473
 
 
 
 
 
 
 
928efab
 
ed9d473
928efab
 
 
 
 
e173a84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import io
import requests
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import uform
from datetime import datetime



# Multilingual uform model; find_topk uses it to preprocess and encode prompts.
model_multi = uform.get_model('unum-cloud/uform-vl-multilingual')

# Embedding matrix that prompt embeddings are compared against in find_topk,
# read from disk and converted to a torch tensor in one step.
embeddings = torch.tensor(np.load('tensors/embeddings.npy'))

#features = np.load('multilingual-image-search/tensors/features.npy')
#features = torch.tensor(features)

# Image metadata table; find_topk reads its 'photo_image_url' column by row
# position, so rows are presumably aligned with `embeddings` - TODO confirm.
img_df = pd.read_csv('image_data.csv')

def url2img(url, resize=False, fix_height=200, timeout=10):
    """Download the image at ``url`` and return it as a PIL image.

    Args:
        url: Address of the image to fetch (redirects are followed).
        resize: When truthy, scale the image to ``fix_height`` pixels high,
            with the width computed by ``resize_wh``.
        fix_height: Target height in pixels, used only when ``resize`` is set.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot block the caller indefinitely.

    Returns:
        The decoded ``PIL.Image.Image``.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            4xx/5xx response.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to the image
    # decoder, which would only report that it cannot identify the file.
    response.raise_for_status()

    img = Image.open(io.BytesIO(response.content))
    if resize:
        width, height = img.size
        img = img.resize(resize_wh(width, height, fix_height))
    return img

def resize_wh(w, h, fix_height):
    """Return the ``(width, height)`` that scales a ``w`` x ``h`` image to ``fix_height``.

    The width is scaled by the original ratio ``w / h`` and truncated to an
    integer, but never drops below one pixel.

    Args:
        w: Original width.
        h: Original height; must be non-zero (a zero height raises
            ``ZeroDivisionError``).
        fix_height: Desired output height.

    Returns:
        A ``(new_width, fix_height)`` tuple.
    """
    ratio = w / h
    # Truncation yields 0 for very tall, narrow images; a zero width is not a
    # usable resize target, so keep at least one pixel.
    w_n = max(1, int(fix_height * ratio))
    return w_n, fix_height

def find_topk(text, top_k=10):
    """Return the images whose embeddings are most similar to a text prompt.

    The prompt is encoded with ``model_multi``, compared against the
    module-level ``embeddings`` by cosine similarity, and the URLs of the
    best-scoring rows of ``img_df`` are downloaded with ``url2img``.

    Args:
        text: The search prompt.
        top_k: Maximum number of images to return. Fewer are returned when
            fewer embeddings are available.

    Returns:
        A list of PIL images, best match first. An empty list is returned
        for an empty or whitespace-only prompt.

    Raises:
        Any exception raised by ``url2img`` while downloading propagates.
    """
    print('text', text)

    # A blank prompt carries no query signal; return an empty gallery rather
    # than arbitrary images.
    if not text or not text.strip():
        return []

    text_data = model_multi.preprocess_text(text)
    # Inference only: skip autograd bookkeeping while encoding the prompt.
    with torch.no_grad():
        _text_features, text_embedding = model_multi.encode_text(
            text_data, return_features=True
        )

    print('Got features', datetime.now().strftime("%H:%M:%S"))

    # Presumably one similarity score per row of `embeddings`
    # (assumes text_embedding is (1, dim) and embeddings is (N, dim)) - TODO confirm.
    sims = F.cosine_similarity(text_embedding, embeddings)

    # topk raises if asked for more entries than exist, so clamp the request.
    k = min(top_k, sims.shape[-1])
    _, inds = sims.topk(k)
    # Index with plain Python ints rather than a tensor.
    top_k_urls = img_df.iloc[inds.tolist()]['photo_image_url'].values

    print('top_k_urls', top_k_urls)
    print(datetime.now().strftime("%H:%M:%S"))

    images = [url2img(url) for url in top_k_urls]

    print('got PIL images')
    print(datetime.now().strftime("%H:%M:%S"))

    return images



# def rerank(text_features, text_data):

#     # create joint embeddings & get scores
#     joint_embedding = model_multi.encode_multimodal(
#         image_features=image_features,
#         text_features=text_features,
#         attention_mask=text_data['attention_mask']
#     )
#     score = model_multi.get_matching_scores(joint_embedding)

#     # argmax to get top N

#     return


#demo = gr.Interface(find_topk, inputs = 'text', outputs = 'image')

print('version', gr.__version__)

# Gradio UI: a language reference table next to the prompt box, with the
# results gallery underneath. Clicking the button runs find_topk on the prompt.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown('# Enter a prompt in one of the supported languages.')
    with gr.Row():
        # Two code/language lists rendered side by side in one Markdown table;
        # the '#' column only separates them visually.
        gr.Markdown(
            '| Code     | Lang     | #    | Code     | Lang                 |\n'
            '| :------- | :------- | :--- | :------- | :------------------- |\n'
            '| eng_Latn | English  | #    | fra_Latn | French               |\n'
            '| deu_Latn | German   | #    | ita_Latn | Italian              |\n'
            '| spa_Latn | Spanish  | #    | jpn_Jpan | Japanese             |\n'
            '| tur_Latn | Turkish  | #    | zho_Hans | Chinese (Simplified) |\n'
            '| kor_Hang | Korean   | #    | pol_Latn | Polish               |\n'
            '| rus_Cyrl | Russian  | #    | .        | .                    |\n'
        )

        with gr.Column():
            prompt_box = gr.Textbox(label='Enter your prompt', lines=3)
            btn_search = gr.Button("Find images")

    gallery = gr.Gallery()
    btn_search.click(find_topk, inputs=prompt_box, outputs=gallery)

if __name__ == "__main__":
    demo.launch()