Spaces:
Build error
Build error
Sujit Pal
commited on
Commit
·
f58917e
1
Parent(s):
5de821f
fix: changing output format to include caption
Browse files- dashboard_image2image.py +14 -13
- dashboard_text2image.py +14 -16
- utils.py +15 -0
dashboard_image2image.py
CHANGED
@@ -12,11 +12,9 @@ import utils
|
|
12 |
|
13 |
BASELINE_MODEL = "openai/clip-vit-base-patch32"
|
14 |
MODEL_PATH = "flax-community/clip-rsicd-v2"
|
15 |
-
|
16 |
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
|
17 |
-
|
18 |
IMAGES_DIR = "./images"
|
19 |
-
|
20 |
|
21 |
@st.cache(allow_output_mutation=True)
|
22 |
def load_example_images():
|
@@ -62,6 +60,7 @@ def download_and_prepare_image(image_url):
|
|
62 |
def app():
|
63 |
filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
|
64 |
model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
|
|
|
65 |
|
66 |
example_image_list = load_example_images()
|
67 |
|
@@ -150,17 +149,19 @@ def app():
|
|
150 |
query_vec = np.asarray(query_vec)
|
151 |
ids, distances = index.knnQuery(query_vec, k=11)
|
152 |
result_filenames = [filenames[id] for id in ids]
|
153 |
-
|
154 |
for result_filename, score in zip(result_filenames, distances):
|
155 |
if image_name is not None and result_filename == image_name:
|
156 |
continue
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
|
|
166 |
suggest_idx = -1
|
|
|
12 |
|
13 |
BASELINE_MODEL = "openai/clip-vit-base-patch32"
|
14 |
MODEL_PATH = "flax-community/clip-rsicd-v2"
|
|
|
15 |
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
|
|
|
16 |
IMAGES_DIR = "./images"
|
17 |
+
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "dataset_rsicd.json")
|
18 |
|
19 |
@st.cache(allow_output_mutation=True)
|
20 |
def load_example_images():
|
|
|
60 |
def app():
|
61 |
filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
|
62 |
model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
|
63 |
+
image2caption = utils.load_captions(CAPTIONS_FILE)
|
64 |
|
65 |
example_image_list = load_example_images()
|
66 |
|
|
|
149 |
query_vec = np.asarray(query_vec)
|
150 |
ids, distances = index.knnQuery(query_vec, k=11)
|
151 |
result_filenames = [filenames[id] for id in ids]
|
152 |
+
rank = 0
|
153 |
for result_filename, score in zip(result_filenames, distances):
|
154 |
if image_name is not None and result_filename == image_name:
|
155 |
continue
|
156 |
+
caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
|
157 |
+
col1, col2, col3 = st.beta_columns([2, 10, 10])
|
158 |
+
col1.markdown("{:d}.".format(rank + 1))
|
159 |
+
col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
|
160 |
+
caption=caption)
|
161 |
+
caption_text = []
|
162 |
+
for caption in image2caption[result_filename]:
|
163 |
+
caption_text.append("* {:s}\n".format(caption))
|
164 |
+
col3.markdown("".join(caption_text))
|
165 |
+
rank += 1
|
166 |
+
st.markdown("---")
|
167 |
suggest_idx = -1
|
dashboard_text2image.py
CHANGED
@@ -4,25 +4,21 @@ import numpy as np
|
|
4 |
import os
|
5 |
import streamlit as st
|
6 |
|
|
|
7 |
from transformers import CLIPProcessor, FlaxCLIPModel
|
8 |
|
9 |
import utils
|
10 |
|
11 |
BASELINE_MODEL = "openai/clip-vit-base-patch32"
|
12 |
-
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
|
13 |
MODEL_PATH = "flax-community/clip-rsicd-v2"
|
14 |
-
|
15 |
-
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
|
16 |
-
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
|
17 |
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
|
18 |
-
|
19 |
-
# IMAGES_DIR = "/home/shared/data/rsicd_images"
|
20 |
IMAGES_DIR = "./images"
|
21 |
-
|
22 |
|
23 |
def app():
|
24 |
filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
|
25 |
model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
|
|
|
26 |
|
27 |
st.title("Retrieve Images given Text")
|
28 |
st.markdown("""
|
@@ -78,13 +74,15 @@ def app():
|
|
78 |
query_vec = np.asarray(query_vec)
|
79 |
ids, distances = index.knnQuery(query_vec, k=10)
|
80 |
result_filenames = [filenames[id] for id in ids]
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
90 |
suggest_idx = -1
|
|
|
4 |
import os
|
5 |
import streamlit as st
|
6 |
|
7 |
+
from PIL import Image
|
8 |
from transformers import CLIPProcessor, FlaxCLIPModel
|
9 |
|
10 |
import utils
|
11 |
|
12 |
BASELINE_MODEL = "openai/clip-vit-base-patch32"
|
|
|
13 |
MODEL_PATH = "flax-community/clip-rsicd-v2"
|
|
|
|
|
|
|
14 |
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
|
|
|
|
|
15 |
IMAGES_DIR = "./images"
|
16 |
+
CAPTIONS_FILE = os.path.join(IMAGES_DIR, "dataset_rsicd.json")
|
17 |
|
18 |
def app():
|
19 |
filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
|
20 |
model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
|
21 |
+
image2caption = utils.load_captions(CAPTIONS_FILE)
|
22 |
|
23 |
st.title("Retrieve Images given Text")
|
24 |
st.markdown("""
|
|
|
74 |
query_vec = np.asarray(query_vec)
|
75 |
ids, distances = index.knnQuery(query_vec, k=10)
|
76 |
result_filenames = [filenames[id] for id in ids]
|
77 |
+
for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
|
78 |
+
caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
|
79 |
+
col1, col2, col3 = st.beta_columns([2, 10, 10])
|
80 |
+
col1.markdown("{:d}.".format(rank + 1))
|
81 |
+
col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
|
82 |
+
caption=caption)
|
83 |
+
caption_text = []
|
84 |
+
for caption in image2caption[result_filename]:
|
85 |
+
caption_text.append("* {:s}\n".format(caption))
|
86 |
+
col3.markdown("".join(caption_text))
|
87 |
+
st.markdown("---")
|
88 |
suggest_idx = -1
|
utils.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import matplotlib.pyplot as plt
|
2 |
import nmslib
|
3 |
import numpy as np
|
@@ -31,3 +32,17 @@ def load_model(model_path, baseline_model):
|
|
31 |
# processor = CLIPProcessor.from_pretrained(baseline_model)
|
32 |
processor = CLIPProcessor.from_pretrained(model_path)
|
33 |
return model, processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
import matplotlib.pyplot as plt
|
3 |
import nmslib
|
4 |
import numpy as np
|
|
|
32 |
# processor = CLIPProcessor.from_pretrained(baseline_model)
|
33 |
processor = CLIPProcessor.from_pretrained(model_path)
|
34 |
return model, processor
|
35 |
+
|
36 |
+
|
37 |
+
@st.cache(allow_output_mutation=True)
|
38 |
+
def load_captions(caption_file):
|
39 |
+
image2caption = {}
|
40 |
+
with open(caption_file, "r") as fcap:
|
41 |
+
data = json.loads(fcap.read())
|
42 |
+
for image in data["images"]:
|
43 |
+
filename = image["filename"]
|
44 |
+
captions = []
|
45 |
+
for sentence in image["sentences"]:
|
46 |
+
captions.append(sentence["raw"])
|
47 |
+
image2caption[filename] = captions
|
48 |
+
return image2caption
|