chanicpanic committed on
Commit
67bc9b1
·
1 Parent(s): 5db88df

Generate multiple alternatives for image search

Browse files
Files changed (3) hide show
  1. app.py +3 -0
  2. search.py +29 -6
  3. vision.py +30 -8
app.py CHANGED
@@ -23,6 +23,9 @@ if "report_submitted" not in st.session_state:
23
  if "grascii" not in st.session_state:
24
  st.session_state["grascii"] = ""
25
 
 
 
 
26
  if st.session_state["report_submitted"]:
27
  st.toast("Thanks for the report!")
28
  st.session_state["report_submitted"] = False
 
23
  if "grascii" not in st.session_state:
24
  st.session_state["grascii"] = ""
25
 
26
+ if "alternatives" not in st.session_state:
27
+ st.session_state["alternatives"] = {}
28
+
29
  if st.session_state["report_submitted"]:
30
  st.toast("Thanks for the report!")
31
  st.session_state["report_submitted"] = False
search.py CHANGED
@@ -31,9 +31,10 @@ def load_images():
31
  image_map = load_images()
32
 
33
 
34
- def set_grascii():
35
  if "grascii_text_box" in st.session_state:
36
  st.session_state["grascii"] = st.session_state["grascii_text_box"]
 
37
 
38
 
39
  def write_grascii_search():
@@ -46,7 +47,10 @@ def write_grascii_search():
46
  placeholder = st.empty()
47
  if search_by == "text":
48
  placeholder.text_input(
49
- "Grascii", value=st.session_state["grascii"], key="grascii_text_box", max_chars=MAX_GRASCII_LENGTH
 
 
 
50
  )
51
  else:
52
  image_data = placeholder.file_uploader(
@@ -74,10 +78,14 @@ def write_grascii_search():
74
  alpha_composite = Image.alpha_composite(background, image)
75
 
76
  arr = np.array([alpha_composite.convert("L")])
77
- tokens = run_vision(arr)
78
- st.session_state["grascii"] = "".join(tokens)
 
 
 
 
79
  if save:
80
- save_image(image_data.getvalue(), "-".join(tokens))
81
 
82
  with st.expander("Options"):
83
  interpretation = st.radio(
@@ -157,7 +165,7 @@ def write_grascii_search():
157
  """,
158
  )
159
 
160
- st.form_submit_button("Search", on_click=set_grascii)
161
 
162
  grascii = st.session_state["grascii"]
163
 
@@ -180,9 +188,24 @@ def write_grascii_search():
180
  if grascii:
181
  st.error(f"Invalid Grascii\n```\n{e.context}\n```")
182
  else:
 
 
 
 
 
 
 
 
183
  write_results(grascii_results, grascii.upper(), "grascii")
184
 
185
 
 
 
 
 
 
 
 
186
  @st.fragment
187
  def write_results(results, term, key_prefix):
188
  rows = map(
 
31
  image_map = load_images()
32
 
33
 
34
+ def on_submit():
35
  if "grascii_text_box" in st.session_state:
36
  st.session_state["grascii"] = st.session_state["grascii_text_box"]
37
+ st.session_state["alternatives"] = {}
38
 
39
 
40
  def write_grascii_search():
 
47
  placeholder = st.empty()
48
  if search_by == "text":
49
  placeholder.text_input(
50
+ "Grascii",
51
+ value=st.session_state["grascii"],
52
+ key="grascii_text_box",
53
+ max_chars=MAX_GRASCII_LENGTH,
54
  )
55
  else:
56
  image_data = placeholder.file_uploader(
 
78
  alpha_composite = Image.alpha_composite(background, image)
79
 
80
  arr = np.array([alpha_composite.convert("L")])
81
+ predictions = run_vision(arr)
82
+ alternatives = {"".join(p): True for p in predictions}
83
+ if st.session_state["alternatives"] != alternatives:
84
+ st.session_state["alternatives"] = alternatives
85
+ st.session_state["grascii"] = "".join(predictions[0])
86
+
87
  if save:
88
+ save_image(image_data.getvalue(), "-".join(predictions[0]))
89
 
90
  with st.expander("Options"):
91
  interpretation = st.radio(
 
165
  """,
166
  )
167
 
168
+ st.form_submit_button("Search", on_click=on_submit)
169
 
170
  grascii = st.session_state["grascii"]
171
 
 
188
  if grascii:
189
  st.error(f"Invalid Grascii\n```\n{e.context}\n```")
190
  else:
191
+ if len(st.session_state["alternatives"]) > 1:
192
+ st.pills(
193
+ "Alternatives",
194
+ st.session_state["alternatives"],
195
+ key="alternative",
196
+ default=grascii,
197
+ on_change=on_alternative_selection,
198
+ )
199
  write_results(grascii_results, grascii.upper(), "grascii")
200
 
201
 
202
+ def on_alternative_selection():
203
+ if st.session_state["alternative"] is None:
204
+ st.session_state["alternative"] = st.session_state["grascii"]
205
+ else:
206
+ st.session_state["grascii"] = st.session_state["alternative"]
207
+
208
+
209
  @st.fragment
210
  def write_results(results, term, key_prefix):
211
  rows = map(
vision.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  from transformers import (
3
  PreTrainedTokenizerFast,
@@ -5,25 +7,45 @@ from transformers import (
5
  ViTImageProcessor,
6
  )
7
 
8
- model_name = "grascii/gregg-vision-v0.2.1"
 
 
 
9
 
10
 
11
- @st.cache_resource(show_spinner=f"Loading {model_name}")
12
  def load_model():
13
  model = VisionEncoderDecoderModel.from_pretrained(
14
- model_name, token=st.secrets.HF_TOKEN
15
  )
16
  tokenizer = PreTrainedTokenizerFast.from_pretrained(
17
- model_name,
18
  token=st.secrets.HF_TOKEN,
19
  )
20
- processor = ViTImageProcessor.from_pretrained(model_name, token=st.secrets.HF_TOKEN)
21
  return model, tokenizer, processor
22
 
23
 
24
- @st.cache_data(ttl=3600, show_spinner=f"Running {model_name}")
25
  def run_vision(image):
26
  model, tokenizer, processor = load_model()
27
  pixel_values = processor(image, return_tensors="pt").pixel_values
28
- generated = model.generate(pixel_values, max_new_tokens=12)[0]
29
- return tokenizer.convert_ids_to_tokens(generated, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
  import streamlit as st
4
  from transformers import (
5
  PreTrainedTokenizerFast,
 
7
  ViTImageProcessor,
8
  )
9
 
10
+
11
+ MODEL_NAME = "grascii/gregg-vision-v0.2.1"
12
+ MIN_LOG_PROB = math.log(0.5)
13
+ NUM_BEAMS = 3
14
 
15
 
16
+ @st.cache_resource(show_spinner=f"Loading {MODEL_NAME}")
17
  def load_model():
18
  model = VisionEncoderDecoderModel.from_pretrained(
19
+ MODEL_NAME, token=st.secrets.HF_TOKEN
20
  )
21
  tokenizer = PreTrainedTokenizerFast.from_pretrained(
22
+ MODEL_NAME,
23
  token=st.secrets.HF_TOKEN,
24
  )
25
+ processor = ViTImageProcessor.from_pretrained(MODEL_NAME, token=st.secrets.HF_TOKEN)
26
  return model, tokenizer, processor
27
 
28
 
29
+ @st.cache_data(ttl=3600, show_spinner=f"Running {MODEL_NAME}")
30
  def run_vision(image):
31
  model, tokenizer, processor = load_model()
32
  pixel_values = processor(image, return_tensors="pt").pixel_values
33
+ generated = model.generate(
34
+ pixel_values,
35
+ max_new_tokens=12,
36
+ num_beams=NUM_BEAMS,
37
+ num_return_sequences=NUM_BEAMS,
38
+ output_scores=True,
39
+ return_dict_in_generate=True,
40
+ )
41
+ return [
42
+ tokenizer.convert_ids_to_tokens(
43
+ generated["sequences"][0], skip_special_tokens=True
44
+ )
45
+ ] + [
46
+ tokenizer.convert_ids_to_tokens(seq, skip_special_tokens=True)
47
+ for seq, score in zip(
48
+ generated["sequences"][1:], generated["sequences_scores"][1:]
49
+ )
50
+ if score > MIN_LOG_PROB
51
+ ]