Spaces:
Runtime error
Runtime error
cwkuo
committed on
Commit
·
bf5fb05
1
Parent(s):
9d121b0
some quick fix
Browse files- app.py +40 -23
- examples/titanic.jpg +0 -3
- model/ckpt/gptk-vicuna7b.pt +2 -2
- requirements.txt +9 -5
app.py
CHANGED
@@ -22,6 +22,7 @@ no_change_btn = gr.Button.update()
|
|
22 |
enable_btn = gr.Button.update(interactive=True)
|
23 |
disable_btn = gr.Button.update(interactive=False)
|
24 |
knwl_none = (None, ) * 30
|
|
|
25 |
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
26 |
|
27 |
|
@@ -29,6 +30,10 @@ def violates_moderation(text):
|
|
29 |
"""
|
30 |
Check whether the text violates OpenAI moderation API.
|
31 |
"""
|
|
|
|
|
|
|
|
|
32 |
url = "https://api.openai.com/v1/moderations"
|
33 |
headers = {
|
34 |
"Content-Type": "application/json",
|
@@ -60,31 +65,32 @@ def regenerate(state: Conversation):
|
|
60 |
prev_human_msg[1] = prev_human_msg[1][:2]
|
61 |
state.skip_next = False
|
62 |
|
63 |
-
return (state, state.to_gradio_chatbot(), "", None, disable_btn, disable_btn)
|
64 |
|
65 |
|
66 |
def clear_history():
|
67 |
state = default_conversation.copy()
|
68 |
-
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,)
|
69 |
|
70 |
|
71 |
def add_text(state: Conversation, text, image):
|
72 |
if len(text) <= 0 and image is None:
|
73 |
state.skip_next = True
|
74 |
-
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) *
|
75 |
|
76 |
if violates_moderation(text):
|
77 |
state.skip_next = True
|
78 |
-
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) *
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
state
|
|
|
83 |
state.append_message(state.roles[0], text)
|
84 |
state.append_message(state.roles[1], None)
|
85 |
state.skip_next = False
|
86 |
|
87 |
-
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) *
|
88 |
|
89 |
|
90 |
def search(image, pos, topk, knwl_db, knwl_idx):
|
@@ -150,9 +156,10 @@ def retrieve_knowledge(image):
|
|
150 |
return knwl_embd, knwl_text
|
151 |
|
152 |
|
153 |
-
|
|
|
154 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
155 |
-
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) *
|
156 |
return
|
157 |
|
158 |
if len(state.messages) == state.offset + 2: # First round of conversation
|
@@ -177,11 +184,16 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
|
|
177 |
for pos in range(knwl_pos):
|
178 |
try:
|
179 |
txt = ""
|
180 |
-
for k, v in knwl[query_type][
|
181 |
v = ", ".join([vi.replace("_", " ") for vi in v])
|
182 |
txt += f"**[{k.upper()}]:** {v}\n\n"
|
183 |
knwl_txt[idx] += txt
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
185 |
except KeyError:
|
186 |
pass
|
187 |
idx += 1
|
@@ -189,13 +201,13 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
|
|
189 |
else:
|
190 |
knwl_embd = None
|
191 |
knwl_vis = knwl_none
|
|
|
192 |
|
193 |
# generate output
|
194 |
-
prompt = state.get_prompt()
|
195 |
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
196 |
image_pt = image_trans(image).to(device).unsqueeze(0)
|
197 |
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
198 |
-
|
199 |
if bool(do_beam_search):
|
200 |
new_text = gptk_model.generate(
|
201 |
samples=samples,
|
@@ -203,6 +215,7 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
|
|
203 |
max_length=min(int(max_new_tokens), 1024),
|
204 |
top_p=float(top_p),
|
205 |
temperature=float(temperature),
|
|
|
206 |
auto_cast=True
|
207 |
)[0]
|
208 |
streamer = [new_text, ]
|
@@ -220,6 +233,7 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
|
|
220 |
temperature=float(temperature),
|
221 |
streamer=streamer,
|
222 |
num_beams=1,
|
|
|
223 |
auto_cast=True
|
224 |
)
|
225 |
)
|
@@ -229,10 +243,10 @@ def generate(state, temperature, top_p, max_new_tokens, add_knwl, do_sampling, d
|
|
229 |
for new_text in streamer:
|
230 |
generated_text += new_text
|
231 |
state.messages[-1][-1] = generated_text + "▌"
|
232 |
-
yield (state, state.to_gradio_chatbot()) + (disable_btn,) *
|
233 |
time.sleep(0.03)
|
234 |
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
235 |
-
yield (state, state.to_gradio_chatbot()) + (enable_btn,) *
|
236 |
|
237 |
|
238 |
title_markdown = ("""
|
@@ -268,15 +282,18 @@ def build_demo():
|
|
268 |
["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
|
269 |
["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
|
270 |
["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
|
271 |
-
["examples/titanic.jpg", "What happen in the scene in this movie?"],
|
272 |
], inputs=[imagebox, textbox])
|
273 |
|
274 |
imagebox.render()
|
275 |
-
textbox.render()
|
276 |
with gr.Row():
|
277 |
-
|
278 |
-
|
279 |
-
|
|
|
|
|
|
|
|
|
|
|
280 |
|
281 |
with gr.Accordion("Parameters", open=True):
|
282 |
with gr.Row():
|
@@ -290,7 +307,7 @@ def build_demo():
|
|
290 |
with gr.Column(scale=6):
|
291 |
chatbot = gr.Chatbot(elem_id="chatbot", label="GPT-K Chatbot", height=550)
|
292 |
|
293 |
-
gr.Markdown("Retrieved Knowledge")
|
294 |
knwl_img, knwl_txt = [], []
|
295 |
for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
|
296 |
with gr.Tab(query_type):
|
@@ -307,7 +324,7 @@ def build_demo():
|
|
307 |
gr.Markdown(learn_more_markdown)
|
308 |
|
309 |
# Register listeners
|
310 |
-
btn_list = [regenerate_btn, clear_btn]
|
311 |
regenerate_btn.click(
|
312 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
313 |
).then(
|
|
|
22 |
enable_btn = gr.Button.update(interactive=True)
|
23 |
disable_btn = gr.Button.update(interactive=False)
|
24 |
knwl_none = (None, ) * 30
|
25 |
+
knwl_unchange = (gr.Image.update(), ) * 15 + (gr.Textbox.update(), ) * 15
|
26 |
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
27 |
|
28 |
|
|
|
30 |
"""
|
31 |
Check whether the text violates OpenAI moderation API.
|
32 |
"""
|
33 |
+
if "OPENAI_API_KEY" not in os.environ:
|
34 |
+
print("OPENAI_API_KEY not found, skip content moderation check...")
|
35 |
+
return True
|
36 |
+
|
37 |
url = "https://api.openai.com/v1/moderations"
|
38 |
headers = {
|
39 |
"Content-Type": "application/json",
|
|
|
65 |
prev_human_msg[1] = prev_human_msg[1][:2]
|
66 |
state.skip_next = False
|
67 |
|
68 |
+
return (state, state.to_gradio_chatbot(), "", None, disable_btn, disable_btn, disable_btn)
|
69 |
|
70 |
|
71 |
def clear_history():
|
72 |
state = default_conversation.copy()
|
73 |
+
return (state, state.to_gradio_chatbot(), "", None) + (enable_btn, disable_btn, disable_btn) + knwl_none
|
74 |
|
75 |
|
76 |
def add_text(state: Conversation, text, image):
|
77 |
if len(text) <= 0 and image is None:
|
78 |
state.skip_next = True
|
79 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 3
|
80 |
|
81 |
if violates_moderation(text):
|
82 |
state.skip_next = True
|
83 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 3
|
84 |
|
85 |
+
if image is not None:
|
86 |
+
text = (text, image)
|
87 |
+
if len(state.get_images(return_pil=True)) > 0:
|
88 |
+
state = default_conversation.copy()
|
89 |
state.append_message(state.roles[0], text)
|
90 |
state.append_message(state.roles[1], None)
|
91 |
state.skip_next = False
|
92 |
|
93 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3
|
94 |
|
95 |
|
96 |
def search(image, pos, topk, knwl_db, knwl_idx):
|
|
|
156 |
return knwl_embd, knwl_text
|
157 |
|
158 |
|
159 |
+
@torch.inference_mode()
|
160 |
+
def generate(state: Conversation, temperature, top_p, max_new_tokens, add_knwl, do_sampling, do_beam_search):
|
161 |
if state.skip_next: # This generate call is skipped due to invalid inputs
|
162 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 + knwl_unchange
|
163 |
return
|
164 |
|
165 |
if len(state.messages) == state.offset + 2: # First round of conversation
|
|
|
184 |
for pos in range(knwl_pos):
|
185 |
try:
|
186 |
txt = ""
|
187 |
+
for k, v in knwl[query_type][pos].items():
|
188 |
v = ", ".join([vi.replace("_", " ") for vi in v])
|
189 |
txt += f"**[{k.upper()}]:** {v}\n\n"
|
190 |
knwl_txt[idx] += txt
|
191 |
+
|
192 |
+
img = images[pos]
|
193 |
+
img = query_trans.transforms[0](img)
|
194 |
+
img = query_trans.transforms[1](img)
|
195 |
+
img = query_trans.transforms[2](img)
|
196 |
+
knwl_img[idx] = img
|
197 |
except KeyError:
|
198 |
pass
|
199 |
idx += 1
|
|
|
201 |
else:
|
202 |
knwl_embd = None
|
203 |
knwl_vis = knwl_none
|
204 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_vis
|
205 |
|
206 |
# generate output
|
207 |
+
prompt = state.get_prompt().replace("USER: <image>\n", "")
|
208 |
prompt = prompt.split("USER:")[-1].replace("ASSISTANT:", "")
|
209 |
image_pt = image_trans(image).to(device).unsqueeze(0)
|
210 |
samples = {"image": image_pt, "knowledge": knwl_embd, "prompt": prompt}
|
|
|
211 |
if bool(do_beam_search):
|
212 |
new_text = gptk_model.generate(
|
213 |
samples=samples,
|
|
|
215 |
max_length=min(int(max_new_tokens), 1024),
|
216 |
top_p=float(top_p),
|
217 |
temperature=float(temperature),
|
218 |
+
length_penalty=0.0,
|
219 |
auto_cast=True
|
220 |
)[0]
|
221 |
streamer = [new_text, ]
|
|
|
233 |
temperature=float(temperature),
|
234 |
streamer=streamer,
|
235 |
num_beams=1,
|
236 |
+
length_penalty=0.0,
|
237 |
auto_cast=True
|
238 |
)
|
239 |
)
|
|
|
243 |
for new_text in streamer:
|
244 |
generated_text += new_text
|
245 |
state.messages[-1][-1] = generated_text + "▌"
|
246 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 + knwl_unchange
|
247 |
time.sleep(0.03)
|
248 |
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
249 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 3 + knwl_unchange
|
250 |
|
251 |
|
252 |
title_markdown = ("""
|
|
|
282 |
["examples/mona_lisa_dog.jpg", "Describe this photo in detail."],
|
283 |
["examples/diamond_head.jpg", "What is the name of this famous sight in the photo?"],
|
284 |
["examples/horseshoe_bend.jpg", "What are the possible reasons of the formation of this sight?"],
|
|
|
285 |
], inputs=[imagebox, textbox])
|
286 |
|
287 |
imagebox.render()
|
|
|
288 |
with gr.Row():
|
289 |
+
with gr.Column(scale=8):
|
290 |
+
textbox.render()
|
291 |
+
with gr.Column(scale=1, min_width=60):
|
292 |
+
submit_btn = gr.Button(value="Submit")
|
293 |
+
|
294 |
+
with gr.Row():
|
295 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, scale=1)
|
296 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False, scale=1)
|
297 |
|
298 |
with gr.Accordion("Parameters", open=True):
|
299 |
with gr.Row():
|
|
|
307 |
with gr.Column(scale=6):
|
308 |
chatbot = gr.Chatbot(elem_id="chatbot", label="GPT-K Chatbot", height=550)
|
309 |
|
310 |
+
gr.Markdown("## Retrieved Knowledge")
|
311 |
knwl_img, knwl_txt = [], []
|
312 |
for query_type, knwl_pos in (("whole", 1), ("five", 5), ("nine", 9)):
|
313 |
with gr.Tab(query_type):
|
|
|
324 |
gr.Markdown(learn_more_markdown)
|
325 |
|
326 |
# Register listeners
|
327 |
+
btn_list = [submit_btn, regenerate_btn, clear_btn]
|
328 |
regenerate_btn.click(
|
329 |
regenerate, [state], [state, chatbot, textbox, imagebox] + btn_list
|
330 |
).then(
|
examples/titanic.jpg
DELETED
Git LFS Details
|
model/ckpt/gptk-vicuna7b.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:502c7a85d5e0d17eb4e823ed8779565dbac832fa16fd659b69e39b9b024c9d26
|
3 |
+
size 564340993
|
requirements.txt
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
-
h5py>=3.8.0
|
2 |
-
transformers==4.30.2
|
3 |
-
faiss-gpu==1.7.2
|
4 |
-
timm==0.4.12
|
5 |
-
openai
|
6 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
7 |
torch==1.11.0+cu113
|
8 |
torchvision==0.12.0+cu113
|
9 |
torchaudio==0.11.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
torch==1.11.0+cu113
|
3 |
torchvision==0.12.0+cu113
|
4 |
torchaudio==0.11.0
|
5 |
+
|
6 |
+
transformers==4.30.2
|
7 |
+
faiss-gpu==1.7.2
|
8 |
+
timm==0.4.12
|
9 |
+
openai
|
10 |
+
open_clip_torch
|
11 |
+
omegaconf
|
12 |
+
h5py>=3.8.0
|
13 |
+
spacy>=3.5.0
|