marinap commited on
Commit
4a3a204
1 Parent(s): c89095f

improved latency: passing urls to gallery directly

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -21,7 +21,7 @@ embeddings = torch.tensor(embeddings)
21
 
22
  img_df = pd.read_csv('image_data.csv')
23
 
24
- def url2img(url, resize = False, fix_height = 200):
25
  data = requests.get(url, allow_redirects = True).content
26
  img = Image.open(io.BytesIO(data))
27
  if resize:
@@ -32,7 +32,7 @@ def find_topk(text):
32
 
33
  print('text', text)
34
 
35
- top_k = 10
36
 
37
  text_data = model_multi.preprocess_text(text)
38
  text_features, text_embedding = model_multi.encode_text(text_data, return_features=True)
@@ -44,15 +44,10 @@ def find_topk(text):
44
  vals, inds = sims.topk(top_k)
45
  top_k_urls = img_df.iloc[inds]['photo_image_url'].values
46
 
47
- print('top_k_urls', top_k_urls)
48
  print(datetime.now().strftime("%H:%M:%S"))
49
 
50
- images = [url2img(url, resize = False) for url in top_k_urls]
51
-
52
- print('got PIL images')
53
- print(datetime.now().strftime("%H:%M:%S"))
54
-
55
- return images
56
 
57
 
58
 
@@ -91,7 +86,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
91
  prompt_box = gr.Textbox(label = 'Enter your prompt', lines = 3)
92
  btn_search = gr.Button("Find images")
93
 
94
- gallery = gr.Gallery().style(columns=5, rows=2, object_fit="contain", height="auto")
95
  btn_search.click(find_topk, inputs = prompt_box, outputs = gallery)
96
 
97
  if __name__ == "__main__":
 
21
 
22
  img_df = pd.read_csv('image_data.csv')
23
 
24
+ def url2img(url, resize = False, fix_height = 150):
25
  data = requests.get(url, allow_redirects = True).content
26
  img = Image.open(io.BytesIO(data))
27
  if resize:
 
32
 
33
  print('text', text)
34
 
35
+ top_k = 20
36
 
37
  text_data = model_multi.preprocess_text(text)
38
  text_features, text_embedding = model_multi.encode_text(text_data, return_features=True)
 
44
  vals, inds = sims.topk(top_k)
45
  top_k_urls = img_df.iloc[inds]['photo_image_url'].values
46
 
47
+ print('Got top_k_urls', top_k_urls)
48
  print(datetime.now().strftime("%H:%M:%S"))
49
 
50
+ return top_k_urls
 
 
 
 
 
51
 
52
 
53
 
 
86
  prompt_box = gr.Textbox(label = 'Enter your prompt', lines = 3)
87
  btn_search = gr.Button("Find images")
88
 
89
+ gallery = gr.Gallery().style(columns = [5], height="auto", object_fit = "scale-down")
90
  btn_search.click(find_topk, inputs = prompt_box, outputs = gallery)
91
 
92
  if __name__ == "__main__":