cwkuo commited on
Commit
bf5fb05
Β·
1 Parent(s): 9d121b0

some quick fix

Browse files
app.py CHANGED
@@ -22,6 +22,7 @@ no_change_btn = gr.Button.update()
22
  enable_btn = gr.Button.update(interactive=True)
23
  disable_btn = gr.Button.update(interactive=False)
24
  knwl_none = (None, ) * 30
 
25
  moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
26
 
27
 
@@ -29,6 +30,10 @@ def violates_moderation(text):
29
  """
30
  Check whether the text violates OpenAI moderation API.
31
  """
 
 
 
 
32
  url = "https://api.openai.com/v1/moderations"
33
  headers = {
34
  "Content-Type": "application/json",
@@ -60,31 +65,32 @@ def regenerate(state: Conversation):
60
  prev_human_msg[1] = prev_human_msg[1][:2]
61
  state.skip_next = False
62
 
63
- return (state, state.to_gradio_chatbot(), "", None, disable_btn, disable_btn)
64
 
65
 
66
  def clear_history():
67
  state = default_conversation.copy()
68
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2 + knwl_none
69
 
70
 
71
  def add_text(state: Conversation, text, image):
72
  if len(text) <= 0 and image is None:
73
  state.skip_next = True
74
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 2
75
 
76
  if violates_moderation(text):
77
  state.skip_next = True
78
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 2
79
 
80
- text = (text, image)
81
- if len(state.get_images(return_pil=True)) > 0:
82
- state = default_conversation.copy()
 
83
  state.append_message(state.roles[0], text)
84
  state.append_message(state.roles[1], None)
85
  state.skip_next = False
86
 
87
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
88
 
89
 
90
  def search(image, pos, topk, knwl_db, knwl_idx):
@@ -150,9 +156,10 @@ def retrieve_knowledge(image):
150
  return knwl_embd, knwl_text
151
 
152
 
153
- def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
 
154
  if state.skip_next: # This generate call is skipped due to invalid inputs
155
- yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2 + knwl_none
156
  return
157
 
158
  if len(state.messages) == state.offset + 2: # First round of conversation
@@ -177,11 +184,16 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
177
  for pos in range(knwl_pos):
178
  try:
179
  txt = ""
180
- for k, v in knwl[query_type][str(pos)].items():
181
  v = ", ".join([vi.replace("_", " ") for vi in v])
182
  txt += f"**[{k.upper()}]:** {v}\n\n"
183
  knwl_txt[idx] += txt
184
- knwl_img[idx] = images[pos]
 
 
 
 
 
185
  except KeyError:
186
  pass
187
  idx += 1
@@ -189,13 +201,13 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
189
  else:
190
  knwl_embd = None
191
  knwl_vis = knwl_none
 
192
 
193
  # generate output
194
- prompt = state.get_prompt()
195
  prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
196
  image_pt = image_trans(image).to(device).unsqueeze(0)
197
  samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
198
-
199
  if bool(do_beam_search):
200
  new_text = gptk_model.generate(
201
  samples=samples,
@@ -203,6 +215,7 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
203
  max_length=min(int(max_new_tokens), 1024),
204
  top_p=float(top_p),
205
  temperature=float(temperature),
 
206
  auto_cast=True
207
  )[0]
208
  streamer = [new_text, ]
@@ -220,6 +233,7 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
220
  temperature=float(temperature),
221
  streamer=streamer,
222
  num_beams=1,
 
223
  auto_cast=True
224
  )
225
  )
@@ -229,10 +243,10 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
229
  for new_text in streamer:
230
  generated_text += new_text
231
  state.messages[-1][-1] = generated_text + "β–Œ"
232
- yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2 + knwl_vis
233
  time.sleep(0.03)
234
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
235
- yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2 + knwl_vis
236
 
237
 
238
  title_markdown = ("""
@@ -268,15 +282,18 @@ def build_demo():
268
  ["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
269
  ["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
270
  ["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
271
- ["examples/titanic.jpg", "What happen in the scene in this movie?"],
272
  ], inputs=[imagebox, textbox])
273
 
274
  imagebox.render()
275
- textbox.render()
276
  with gr.Row():
277
- submit_btn = gr.Button(value="πŸ“ Submit")
278
- regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False)
279
- clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
 
 
 
 
 
280
 
281
  with gr.Accordion("Parameters", open=True):
282
  with gr.Row():
@@ -290,7 +307,7 @@ def build_demo():
290
  with gr.Column(scale=6):
291
  chatbot = gr.Chatbot(elem_id="chatbot", label="GPT-K Chatbot", height=550)
292
 
293
- gr.Markdown("Retrieved Knowledge")
294
  knwl_img, knwl_txt = [], []
295
  for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
296
  with gr.Tab(query_type):
@@ -307,7 +324,7 @@ def build_demo():
307
  gr.Markdown(learn_more_markdown)
308
 
309
  # Register listeners
310
- btn_list = [regenerate_btn, clear_btn]
311
  regenerate_btn.click(
312
  regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
313
  ).then(
 
22
  enable_btn = gr.Button.update(interactive=True)
23
  disable_btn = gr.Button.update(interactive=False)
24
  knwl_none = (None, ) * 30
25
+ knwl_unchange = (gr.Image.update(), ) * 15 + (gr.Textbox.update(), ) * 15
26
  moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
27
 
28
 
 
30
  """
31
  Check whether the text violates OpenAI moderation API.
32
  """
33
+ if "OPENAI_API_KEY" not in os.environ:
34
+ print("OPENAI_API_KEY not found, skip content moderation check...")
35
+ return True
36
+
37
  url = "https://api.openai.com/v1/moderations"
38
  headers = {
39
  "Content-Type": "application/json",
 
65
  prev_human_msg[1] = prev_human_msg[1][:2]
66
  state.skip_next = False
67
 
68
+ return (state, state.to_gradio_chatbot(), "", None, disable_btn, disable_btn, disable_btn)
69
 
70
 
71
  def clear_history():
72
  state = default_conversation.copy()
73
+ return (state, state.to_gradio_chatbot(), "", None) + (enable_btn, disable_btn, disable_btn) + knwl_none
74
 
75
 
76
  def add_text(state: Conversation, text, image):
77
  if len(text) <= 0 and image is None:
78
  state.skip_next = True
79
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 3
80
 
81
  if violates_moderation(text):
82
  state.skip_next = True
83
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 3
84
 
85
+ if image is not None:
86
+ text = (text, image)
87
+ if len(state.get_images(return_pil=True)) > 0:
88
+ state = default_conversation.copy()
89
  state.append_message(state.roles[0], text)
90
  state.append_message(state.roles[1], None)
91
  state.skip_next = False
92
 
93
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3
94
 
95
 
96
  def search(image, pos, topk, knwl_db, knwl_idx):
 
156
  return knwl_embd, knwl_text
157
 
158
 
159
+ @torch.inference_mode()
160
+ def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
161
  if state.skip_next: # This generate call is skipped due to invalid inputs
162
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
163
  return
164
 
165
  if len(state.messages) == state.offset + 2: # First round of conversation
 
184
  for pos in range(knwl_pos):
185
  try:
186
  txt = ""
187
+ for k, v in knwl[query_type][pos].items():
188
  v = ", ".join([vi.replace("_", " ") for vi in v])
189
  txt += f"**[{k.upper()}]:** {v}\n\n"
190
  knwl_txt[idx] += txt
191
+
192
+ img = images[pos]
193
+ img = query_trans.transforms[0](img)
194
+ img = query_trans.transforms[1](img)
195
+ img = query_trans.transforms[2](img)
196
+ knwl_img[idx] = img
197
  except KeyError:
198
  pass
199
  idx += 1
 
201
  else:
202
  knwl_embd = None
203
  knwl_vis = knwl_none
204
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_vis
205
 
206
  # generate output
207
+ prompt = state.get_prompt().replace("USER: <image>\n", "")
208
  prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
209
  image_pt = image_trans(image).to(device).unsqueeze(0)
210
  samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
 
211
  if bool(do_beam_search):
212
  new_text = gptk_model.generate(
213
  samples=samples,
 
215
  max_length=min(int(max_new_tokens), 1024),
216
  top_p=float(top_p),
217
  temperature=float(temperature),
218
+ length_penalty=0.0,
219
  auto_cast=True
220
  )[0]
221
  streamer = [new_text, ]
 
233
  temperature=float(temperature),
234
  streamer=streamer,
235
  num_beams=1,
236
+ length_penalty=0.0,
237
  auto_cast=True
238
  )
239
  )
 
243
  for new_text in streamer:
244
  generated_text += new_text
245
  state.messages[-1][-1] = generated_text + "β–Œ"
246
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_unchange
247
  time.sleep(0.03)
248
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
249
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 3 + knwl_unchange
250
 
251
 
252
  title_markdown = ("""
 
282
  ["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
283
  ["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
284
  ["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
 
285
  ], inputs=[imagebox, textbox])
286
 
287
  imagebox.render()
 
288
  with gr.Row():
289
+ with gr.Column(scale=8):
290
+ textbox.render()
291
+ with gr.Column(scale=1, min_width=60):
292
+ submit_btn = gr.Button(value="Submit")
293
+
294
+ with gr.Row():
295
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False, scale=1)
296
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False, scale=1)
297
 
298
  with gr.Accordion("Parameters", open=True):
299
  with gr.Row():
 
307
  with gr.Column(scale=6):
308
  chatbot = gr.Chatbot(elem_id="chatbot", label="GPT-K Chatbot", height=550)
309
 
310
+ gr.Markdown("## Retrieved Knowledge")
311
  knwl_img, knwl_txt = [], []
312
  for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
313
  with gr.Tab(query_type):
 
324
  gr.Markdown(learn_more_markdown)
325
 
326
  # Register listeners
327
+ btn_list = [submit_btn, regenerate_btn, clear_btn]
328
  regenerate_btn.click(
329
  regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
330
  ).then(
examples/titanic.jpg DELETED

Git LFS Details

  • SHA256: e730a4a2d3efd7a99d5e120d22000cc51cf81176e32aa677fd2be1ea8dfb4a63
  • Pointer size: 131 Bytes
  • Size of remote file: 439 kB
model/ckpt/gptk-vicuna7b.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bb27e6bbdc6f93ac950d265287c8388824f106e17bcab5cf5254810ca9c6790f
3
- size 564340835
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:502c7a85d5e0d17eb4e823ed8779565dbac832fa16fd659b69e39b9b024c9d26
3
+ size 564340993
requirements.txt CHANGED
@@ -1,9 +1,13 @@
1
- h5py>=3.8.0
2
- transformers==4.30.2
3
- faiss-gpu==1.7.2
4
- timm==0.4.12
5
- openai
6
  --extra-index-url https://download.pytorch.org/whl/cu113
7
  torch==1.11.0+cu113
8
  torchvision==0.12.0+cu113
9
  torchaudio==0.11.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  --extra-index-url https://download.pytorch.org/whl/cu113
2
  torch==1.11.0+cu113
3
  torchvision==0.12.0+cu113
4
  torchaudio==0.11.0
5
+
6
+ transformers==4.30.2
7
+ faiss-gpu==1.7.2
8
+ timm==0.4.12
9
+ openai
10
+ open_clip_torch
11
+ omegaconf
12
+ h5py>=3.8.0
13
+ spacy>=3.5.0