jaketae commited on
Commit
48a1fa8
1 Parent(s): bf9c2d9

features: overall ui cleanup

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. text2image.py +17 -13
  3. most_relevant_part.py → text2patch.py +34 -31
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import streamlit as st
2
 
3
  import image2text
4
- import most_relevant_part
5
  import text2image
 
6
 
7
  PAGES = {
8
  "Text to Image": text2image,
9
  "Image to Text": image2text,
10
- "Most Relevant Part of Image": most_relevant_part,
11
  }
12
 
13
  st.sidebar.title("Navigation")
 
1
  import streamlit as st
2
 
3
  import image2text
 
4
  import text2image
5
+ import text2patch
6
 
7
  PAGES = {
8
  "Text to Image": text2image,
9
  "Image to Text": image2text,
10
+ "Patch Importance Ranking": text2patch,
11
  }
12
 
13
  st.sidebar.title("Navigation")
text2image.py CHANGED
@@ -33,16 +33,20 @@ def app(model_name):
33
 
34
  query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="컴퓨터하는 고양이")
35
  if st.button("질문 (Query)"):
36
- proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
37
- vec = np.asarray(model.get_text_features(**proc))
38
- ids, dists = index.knnQuery(vec, k=10)
39
- result_files = map(lambda id: files[id], ids)
40
- result_imgs, result_captions = [], []
41
- for file, dist in zip(result_files, dists):
42
- result_imgs.append(plt.imread(os.path.join(images_directory, file)))
43
- result_captions.append("Score: {:.3f}".format(1.0 - dist))
44
-
45
- st.image(result_imgs[:3], caption=result_captions[:3], width=200)
46
- st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
47
- st.image(result_imgs[6:9], caption=result_captions[6:9], width=200)
48
- st.image(result_imgs[9:], caption=result_captions[9:], width=200)
 
 
 
 
 
33
 
34
  query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="컴퓨터하는 고양이")
35
  if st.button("질문 (Query)"):
36
+ st.markdown("""---""")
37
+ with st.spinner("Computing..."):
38
+ proc = processor(
39
+ text=[query], images=None, return_tensors="jax", padding=True
40
+ )
41
+ vec = np.asarray(model.get_text_features(**proc))
42
+ ids, dists = index.knnQuery(vec, k=10)
43
+ result_files = map(lambda id: files[id], ids)
44
+ result_imgs, result_captions = [], []
45
+ for file, dist in zip(result_files, dists):
46
+ result_imgs.append(plt.imread(os.path.join(images_directory, file)))
47
+ result_captions.append("Score: {:.3f}".format(1.0 - dist))
48
+
49
+ st.image(result_imgs[:3], caption=result_captions[:3], width=200)
50
+ st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
51
+ st.image(result_imgs[6:9], caption=result_captions[6:9], width=200)
52
+ st.image(result_imgs[9:], caption=result_captions[9:], width=200)
most_relevant_part.py → text2patch.py RENAMED
@@ -22,21 +22,10 @@ def split_image(im, num_rows=3, num_cols=3):
22
  return tiles
23
 
24
 
25
- # def split_image(X):
26
- # num_rows = X.shape[0] // 224
27
- # num_cols = X.shape[1] // 224
28
- # Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
29
- # patches = []
30
- # for j in range(num_rows):
31
- # for i in range(num_cols):
32
- # patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
33
- # return patches
34
-
35
-
36
  def app(model_name):
37
  model, processor = load_model(f"koclip/{model_name}")
38
 
39
- st.title("Most Relevant Part of Image")
40
  st.markdown(
41
  """
42
  Given a piece of text, the CLIP model finds the part of an image that best explains the text.
@@ -60,29 +49,43 @@ def app(model_name):
60
  "Enter query to find most relevant part of image ",
61
  value="이건 서울의 경복궁 사진이다.",
62
  )
63
- num_rows = st.slider("Number of rows", min_value=1, max_value=5, value=3, step=1)
64
- num_cols = st.slider("Number of columns", min_value=1, max_value=5, value=3, step=1)
 
 
 
 
 
 
 
 
65
 
66
  if st.button("질문 (Query)"):
67
  if not any([query1, query2]):
68
  st.error("Please upload an image or paste an image URL.")
69
  else:
70
- image_data = (
71
- query2 if query2 is not None else requests.get(query1, stream=True).raw
72
- )
73
- image = Image.open(image_data)
74
- st.image(image)
 
 
 
 
75
 
76
- images = split_image(image, num_rows, num_cols)
77
 
78
- inputs = processor(
79
- text=captions, images=images, return_tensors="jax", padding=True
80
- )
81
- inputs["pixel_values"] = jnp.transpose(
82
- inputs["pixel_values"], axes=[0, 2, 3, 1]
83
- )
84
- outputs = model(**inputs)
85
- probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
86
- for idx, prob in sorted(enumerate(probs), key=lambda x: x[1], reverse=True):
87
- st.text(f"Score: {prob[0]:.3f}")
88
- st.image(images[idx])
 
 
 
22
  return tiles
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  def app(model_name):
26
  model, processor = load_model(f"koclip/{model_name}")
27
 
28
+ st.title("Patch-based Relevance Retrieval")
29
  st.markdown(
30
  """
31
  Given a piece of text, the CLIP model finds the part of an image that best explains the text.
 
49
  "Enter query to find most relevant part of image ",
50
  value="이건 서울의 경복궁 사진이다.",
51
  )
52
+
53
+ col1, col2 = st.beta_columns(2)
54
+ with col1:
55
+ num_rows = st.slider(
56
+ "Number of rows", min_value=1, max_value=5, value=3, step=1
57
+ )
58
+ with col2:
59
+ num_cols = st.slider(
60
+ "Number of columns", min_value=1, max_value=5, value=3, step=1
61
+ )
62
 
63
  if st.button("질문 (Query)"):
64
  if not any([query1, query2]):
65
  st.error("Please upload an image or paste an image URL.")
66
  else:
67
+ st.markdown("""---""")
68
+ with st.spinner("Computing..."):
69
+ image_data = (
70
+ query2
71
+ if query2 is not None
72
+ else requests.get(query1, stream=True).raw
73
+ )
74
+ image = Image.open(image_data)
75
+ st.image(image)
76
 
77
+ images = split_image(image, num_rows, num_cols)
78
 
79
+ inputs = processor(
80
+ text=captions, images=images, return_tensors="jax", padding=True
81
+ )
82
+ inputs["pixel_values"] = jnp.transpose(
83
+ inputs["pixel_values"], axes=[0, 2, 3, 1]
84
+ )
85
+ outputs = model(**inputs)
86
+ probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
87
+ for idx, prob in sorted(
88
+ enumerate(probs), key=lambda x: x[1], reverse=True
89
+ ):
90
+ st.text(f"Score: {prob[0]:.3f}")
91
+ st.image(images[idx])