Commit 67bc9b1
Parent(s): 5db88df

Generate multiple alternatives for image search
app.py
CHANGED
@@ -23,6 +23,9 @@ if "report_submitted" not in st.session_state:
 if "grascii" not in st.session_state:
     st.session_state["grascii"] = ""
 
+if "alternatives" not in st.session_state:
+    st.session_state["alternatives"] = {}
+
 if st.session_state["report_submitted"]:
     st.toast("Thanks for the report!")
     st.session_state["report_submitted"] = False
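For context on the new "alternatives" key: Streamlit reruns the whole script on every interaction, so values that must survive a rerun live in st.session_state and are assigned only when the key is absent. A minimal standalone sketch of that init-once pattern (the "counter" key is hypothetical, purely for illustration):

import streamlit as st

# This check runs on every rerun, but assigns only the first time;
# later reruns keep whatever value is already stored.
if "counter" not in st.session_state:
    st.session_state["counter"] = 0

if st.button("Increment"):
    st.session_state["counter"] += 1

st.write("Counter:", st.session_state["counter"])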
search.py
CHANGED
@@ -31,9 +31,10 @@ def load_images():
 image_map = load_images()
 
 
-def
+def on_submit():
     if "grascii_text_box" in st.session_state:
         st.session_state["grascii"] = st.session_state["grascii_text_box"]
+    st.session_state["alternatives"] = {}
 
 
 def write_grascii_search():
@@ -46,7 +47,10 @@ def write_grascii_search():
     placeholder = st.empty()
     if search_by == "text":
         placeholder.text_input(
-            "Grascii",
+            "Grascii",
+            value=st.session_state["grascii"],
+            key="grascii_text_box",
+            max_chars=MAX_GRASCII_LENGTH,
         )
     else:
         image_data = placeholder.file_uploader(
@@ -74,10 +78,14 @@ def write_grascii_search():
         alpha_composite = Image.alpha_composite(background, image)
 
         arr = np.array([alpha_composite.convert("L")])
-
-
+        predictions = run_vision(arr)
+        alternatives = {"".join(p): True for p in predictions}
+        if st.session_state["alternatives"] != alternatives:
+            st.session_state["alternatives"] = alternatives
+            st.session_state["grascii"] = "".join(predictions[0])
+
         if save:
-            save_image(image_data.getvalue(), "-".join(
+            save_image(image_data.getvalue(), "-".join(predictions[0]))
 
         with st.expander("Options"):
             interpretation = st.radio(
@@ -157,7 +165,7 @@ def write_grascii_search():
         """,
     )
 
-    st.form_submit_button("Search", on_click=
+    st.form_submit_button("Search", on_click=on_submit)
 
     grascii = st.session_state["grascii"]
 
@@ -180,9 +188,24 @@ def write_grascii_search():
     if grascii:
         st.error(f"Invalid Grascii\n```\n{e.context}\n```")
     else:
+        if len(st.session_state["alternatives"]) > 1:
+            st.pills(
+                "Alternatives",
+                st.session_state["alternatives"],
+                key="alternative",
+                default=grascii,
+                on_change=on_alternative_selection,
+            )
         write_results(grascii_results, grascii.upper(), "grascii")
 
 
+def on_alternative_selection():
+    if st.session_state["alternative"] is None:
+        st.session_state["alternative"] = st.session_state["grascii"]
+    else:
+        st.session_state["grascii"] = st.session_state["alternative"]
+
+
 @st.fragment
 def write_results(results, term, key_prefix):
     rows = map(
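The pills row writes back into the same "grascii" state the text box reads from. A self-contained sketch of that round trip, with made-up candidate strings; because st.pills lets the user deselect the active pill, the callback restores the current value instead of accepting None:

import streamlit as st

if "grascii" not in st.session_state:
    st.session_state["grascii"] = "AB"


def on_alternative_selection():
    # Deselecting the active pill yields None; keep the previous
    # value instead of clearing the search term.
    if st.session_state["alternative"] is None:
        st.session_state["alternative"] = st.session_state["grascii"]
    else:
        st.session_state["grascii"] = st.session_state["alternative"]


st.pills(
    "Alternatives",
    ["AB", "A-B", "EB"],  # hypothetical predictions
    key="alternative",
    default=st.session_state["grascii"],
    on_change=on_alternative_selection,
)
st.write("Current Grascii:", st.session_state["grascii"])

Mutating the widget's own key inside its on_change callback is safe here because Streamlit runs callbacks before the script reruns and re-instantiates the widget.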
vision.py
CHANGED
@@ -1,3 +1,5 @@
+import math
+
 import streamlit as st
 from transformers import (
     PreTrainedTokenizerFast,
@@ -5,25 +7,45 @@ from transformers import (
     ViTImageProcessor,
 )
 
-
+
+MODEL_NAME = "grascii/gregg-vision-v0.2.1"
+MIN_LOG_PROB = math.log(0.5)
+NUM_BEAMS = 3
 
 
-@st.cache_resource(show_spinner=f"Loading {
+@st.cache_resource(show_spinner=f"Loading {MODEL_NAME}")
 def load_model():
     model = VisionEncoderDecoderModel.from_pretrained(
-
+        MODEL_NAME, token=st.secrets.HF_TOKEN
     )
     tokenizer = PreTrainedTokenizerFast.from_pretrained(
-
+        MODEL_NAME,
         token=st.secrets.HF_TOKEN,
     )
-    processor = ViTImageProcessor.from_pretrained(
+    processor = ViTImageProcessor.from_pretrained(MODEL_NAME, token=st.secrets.HF_TOKEN)
     return model, tokenizer, processor
 
 
-@st.cache_data(ttl=3600, show_spinner=f"Running {
+@st.cache_data(ttl=3600, show_spinner=f"Running {MODEL_NAME}")
 def run_vision(image):
     model, tokenizer, processor = load_model()
     pixel_values = processor(image, return_tensors="pt").pixel_values
-    generated = model.generate(
-
+    generated = model.generate(
+        pixel_values,
+        max_new_tokens=12,
+        num_beams=NUM_BEAMS,
+        num_return_sequences=NUM_BEAMS,
+        output_scores=True,
+        return_dict_in_generate=True,
+    )
+    return [
+        tokenizer.convert_ids_to_tokens(
+            generated["sequences"][0], skip_special_tokens=True
+        )
+    ] + [
+        tokenizer.convert_ids_to_tokens(seq, skip_special_tokens=True)
+        for seq, score in zip(
+            generated["sequences"][1:], generated["sequences_scores"][1:]
+        )
+        if score > MIN_LOG_PROB
+    ]