features: overall ui cleanup
- app.py +2 -2
- text2image.py +17 -13
- most_relevant_part.py → text2patch.py +34 -31
app.py
CHANGED
@@ -1,13 +1,13 @@
 import streamlit as st
 
 import image2text
-import most_relevant_part
 import text2image
+import text2patch
 
 PAGES = {
     "Text to Image": text2image,
     "Image to Text": image2text,
-    "
+    "Patch Importance Ranking": text2patch,
 }
 
 st.sidebar.title("Navigation")
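The PAGES dict maps sidebar labels to page modules, so swapping in the renamed text2patch page is a two-line change. As a minimal sketch of the dispatch that presumably follows the sidebar title (the radio widget and the model selector below are assumptions, not shown in this diff):

import streamlit as st

# Hypothetical tail of app.py: pick a page from PAGES and hand it the model
# name. Each page module exposes app(model_name), as the
# def app(model_name) signature in text2patch.py confirms.
selection = st.sidebar.radio("Go to", list(PAGES.keys()))
model_name = st.sidebar.selectbox("Model", ["koclip-base", "koclip-large"])  # assumed names
PAGES[selection].app(model_name)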
text2image.py
CHANGED
@@ -33,16 +33,20 @@ def app(model_name):
 
     query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="컴퓨터하는 고양이")
     if st.button("질문 (Query)"):
-        … (13 lines of the previous handler; content not captured in this view)
+        st.markdown("""---""")
+        with st.spinner("Computing..."):
+            proc = processor(
+                text=[query], images=None, return_tensors="jax", padding=True
+            )
+            vec = np.asarray(model.get_text_features(**proc))
+            ids, dists = index.knnQuery(vec, k=10)
+            result_files = map(lambda id: files[id], ids)
+            result_imgs, result_captions = [], []
+            for file, dist in zip(result_files, dists):
+                result_imgs.append(plt.imread(os.path.join(images_directory, file)))
+                result_captions.append("Score: {:.3f}".format(1.0 - dist))
+
+            st.image(result_imgs[:3], caption=result_captions[:3], width=200)
+            st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
+            st.image(result_imgs[6:9], caption=result_captions[6:9], width=200)
+            st.image(result_imgs[9:], caption=result_captions[9:], width=200)
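The new handler embeds the Korean query with the CLIP text tower and fetches the ten nearest images via index.knnQuery, reporting 1.0 - dist as the score, which matches a cosine-distance index. A minimal sketch of how that index might be built; nmslib as the library and the placeholder embeddings are assumptions, only knnQuery(vec, k=10) appears in the diff:

import nmslib
import numpy as np

# Assumed offline step: embed every file in images_directory with
# model.get_image_features(...) into an (N, D) array; random placeholders here.
image_vectors = np.random.rand(100, 512).astype(np.float32)

index = nmslib.init(method="hnsw", space="cosinesimil")
index.addDataPointBatch(image_vectors)
index.createIndex({"post": 2}, print_progress=False)

# Query-time call mirrors the diff: dists are cosine distances, so the app's
# "Score: {:.3f}".format(1.0 - dist) prints cosine similarity.
text_vector = np.random.rand(512).astype(np.float32)  # placeholder query embedding
ids, dists = index.knnQuery(text_vector, k=10)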
most_relevant_part.py → text2patch.py
RENAMED
@@ -22,21 +22,10 @@ def split_image(im, num_rows=3, num_cols=3):
     return tiles
 
 
-# def split_image(X):
-#     num_rows = X.shape[0] // 224
-#     num_cols = X.shape[1] // 224
-#     Xc = X[0:num_rows * 224, 0:num_cols * 224, :]
-#     patches = []
-#     for j in range(num_rows):
-#         for i in range(num_cols):
-#             patches.append(Xc[j * 224:(j + 1) * 224, i * 224:(i + 1) * 224, :])
-#     return patches
-
-
 def app(model_name):
     model, processor = load_model(f"koclip/{model_name}")
 
-    st.title("
+    st.title("Patch-based Relevance Retrieval")
     st.markdown(
         """
         Given a piece of text, the CLIP model finds the part of an image that best explains the text.
@@ -60,29 +49,43 @@ def app(model_name):
         "Enter query to find most relevant part of image ",
         value="이건 서울의 경복궁 사진이다.",
     )
-    … (2 removed lines not captured in this view)
+
+    col1, col2 = st.beta_columns(2)
+    with col1:
+        num_rows = st.slider(
+            "Number of rows", min_value=1, max_value=5, value=3, step=1
+        )
+    with col2:
+        num_cols = st.slider(
+            "Number of columns", min_value=1, max_value=5, value=3, step=1
+        )
 
     if st.button("질문 (Query)"):
         if not any([query1, query2]):
             st.error("Please upload an image or paste an image URL.")
         else:
-            … (5 removed lines not captured in this view)
+            st.markdown("""---""")
+            with st.spinner("Computing..."):
+                image_data = (
+                    query2
+                    if query2 is not None
+                    else requests.get(query1, stream=True).raw
+                )
+                image = Image.open(image_data)
+                st.image(image)
 
-            … (1 removed line not captured in this view)
+                images = split_image(image, num_rows, num_cols)
 
-            … (11 removed lines not captured in this view)
+                inputs = processor(
+                    text=captions, images=images, return_tensors="jax", padding=True
+                )
+                inputs["pixel_values"] = jnp.transpose(
+                    inputs["pixel_values"], axes=[0, 2, 3, 1]
+                )
+                outputs = model(**inputs)
+                probs = jax.nn.softmax(outputs.logits_per_image, axis=0)
+                for idx, prob in sorted(
+                    enumerate(probs), key=lambda x: x[1], reverse=True
+                ):
+                    st.text(f"Score: {prob[0]:.3f}")
+                    st.image(images[idx])
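The page now lets the user pick the grid size, splits the image into num_rows × num_cols tiles, and scores every tile against the query. jax.nn.softmax(outputs.logits_per_image, axis=0) normalizes over the patch axis, so the printed scores compare tiles against each other for the same text, and the jnp.transpose to axes=[0, 2, 3, 1] converts the processor's channel-first pixel values to the channel-last layout the Flax vision tower expects. Only the signature def split_image(im, num_rows=3, num_cols=3) and its return tiles survive in this view; the sketch below is a plausible reconstruction, not the file's actual body:

import numpy as np
from PIL import Image

def split_image(im, num_rows=3, num_cols=3):
    # Sketch reconstructed from the hunk header: crop the PIL image into an
    # even num_rows x num_cols grid and return the tiles as PIL images.
    arr = np.asarray(im)
    tile_h = arr.shape[0] // num_rows
    tile_w = arr.shape[1] // num_cols
    tiles = []
    for r in range(num_rows):
        for c in range(num_cols):
            tile = arr[r * tile_h : (r + 1) * tile_h, c * tile_w : (c + 1) * tile_w]
            tiles.append(Image.fromarray(tile))
    return tiles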