Trent commited on
Commit
cf349fd
โ€ข
1 Parent(s): 2cd1913

Text to image Search Engine demo

Browse files
Files changed (4) hide show
  1. app.py +1 -1
  2. requirements.txt +3 -1
  3. text2image.py +34 -4
  4. utils.py +17 -0
app.py CHANGED
@@ -10,4 +10,4 @@ st.sidebar.title("Navigation")
10
  model = st.sidebar.selectbox("Choose a model", ["koclip", "koclip-large"])
11
  page = st.sidebar.selectbox("Choose a task", list(PAGES.keys()))
12
 
13
- PAGES[page].app(f"koclip/{model}")
 
10
  model = st.sidebar.selectbox("Choose a model", ["koclip", "koclip-large"])
11
  page = st.sidebar.selectbox("Choose a task", list(PAGES.keys()))
12
 
13
+ PAGES[page].app(model)
requirements.txt CHANGED
@@ -3,4 +3,6 @@ jaxlib
3
  flax
4
  transformers
5
  streamlit
6
- tqdm
 
 
 
3
  flax
4
  transformers
5
  streamlit
6
+ tqdm
7
+ nmslib
8
+ matplotlib
text2image.py CHANGED
@@ -1,14 +1,44 @@
 
 
1
  import streamlit as st
2
 
3
- from utils import load_model
 
 
4
 
5
 
6
  def app(model_name):
7
- model, processor = load_model(model_name)
 
8
 
 
 
9
 
10
- st.title("Text to Image")
11
  st.markdown("""
12
- Some text goes in here.
 
 
 
 
 
 
 
 
13
  """)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
  import streamlit as st
4
 
5
+ from utils import load_model, load_index
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
 
9
 
10
  def app(model_name):
11
+ images_directory = 'images/val2017'
12
+ features_directory = f'features/val2017/{model_name}.tsv'
13
 
14
+ files, index = load_index(features_directory)
15
+ model, processor = load_model(f'koclip/{model_name}')
16
 
17
+ st.title("Text to Image Search Engine")
18
  st.markdown("""
19
+ This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
20
+ 5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
21
+ vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
22
+ are displayed below.
23
+
24
+ KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and
25
+ Korean caption annotations. Korean translation of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence).
26
+
27
+ Example Queries : ์•„ํŒŒํŠธ(Apartment), ์ž๋™์ฐจ(Car), ์ปดํ“จํ„ฐ(Computer)
28
  """)
29
 
30
+ query = st.text_input("ํ•œ๊ธ€ ์งˆ๋ฌธ์„ ์ ์–ด์ฃผ์„ธ์š” (Korean Text Query) :", value="์•„ํŒŒํŠธ")
31
+ if st.button("์งˆ๋ฌธ (Query)"):
32
+ proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
33
+ vec = np.asarray(model.get_text_features(**proc))
34
+ ids, dists = index.knnQuery(vec, k=10)
35
+ result_files = map(lambda id: files[id], ids)
36
+ result_imgs, result_captions = [], []
37
+ for file, dist in zip(result_files, dists):
38
+ result_imgs.append(plt.imread(os.path.join(images_directory, file)))
39
+ result_captions.append("{:s} (์œ ์‚ฌ๋„: {:.3f})".format(file, 1.0 - dist))
40
+
41
+ st.image(result_imgs[:3], caption=result_captions[:3], width=200)
42
+ st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
43
+ st.image(result_imgs[6:9], caption=result_captions[6:9], width=200)
44
+ st.image(result_imgs[9:], caption=result_captions[9:], width=200)
utils.py CHANGED
@@ -1,8 +1,25 @@
 
1
  import streamlit as st
2
  from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
 
3
 
4
  from koclip import FlaxHybridCLIP
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  @st.cache(allow_output_mutation=True)
8
  def load_model(model_name="koclip/koclip"):
 
1
+ import nmslib
2
  import streamlit as st
3
  from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
4
+ import numpy as np
5
 
6
  from koclip import FlaxHybridCLIP
7
 
8
+ @st.cache(allow_output_mutation=True)
9
+ def load_index(img_file):
10
+ filenames, embeddings = [], []
11
+ lines = open(img_file, "r")
12
+ for line in lines:
13
+ cols = line.strip().split('\t')
14
+ filename = cols[0]
15
+ embedding = np.array([float(x) for x in cols[1].split(',')])
16
+ filenames.append(filename)
17
+ embeddings.append(embedding)
18
+ embeddings = np.array(embeddings)
19
+ index = nmslib.init(method='hnsw', space='cosinesimil')
20
+ index.addDataPointBatch(embeddings)
21
+ index.createIndex({'post': 2}, print_progress=True)
22
+ return filenames, index
23
 
24
  @st.cache(allow_output_mutation=True)
25
  def load_model(model_name="koclip/koclip"):