yuexiang96 commited on
Commit
17c6e95
·
verified ·
1 Parent(s): 70f1fe7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -75
app.py CHANGED
@@ -62,6 +62,7 @@ repo_name = os.environ["LOG_REPO"]
62
 
63
  external_log_dir = "./logs"
64
  LOGDIR = external_log_dir
 
65
 
66
 
67
  def install_gradio_4_35_0():
@@ -87,6 +88,38 @@ def get_conv_log_filename():
87
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
88
  return name
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  class InferenceDemo(object):
91
  def __init__(
92
  self, args, model_path, tokenizer, model, image_processor, context_len
@@ -125,6 +158,22 @@ class InferenceDemo(object):
125
  self.conversation = conv_templates[args.conv_mode].copy()
126
  self.num_frames = args.num_frames
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  def is_valid_video_filename(name):
130
  video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
@@ -178,13 +227,6 @@ def load_image(image_file):
178
  return image
179
 
180
 
181
- def clear_history(history):
182
-
183
- our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
184
-
185
- return None
186
-
187
-
188
  def clear_response(history):
189
  for index_conv in range(1, len(history)):
190
  # loop until get a text response from our model.
@@ -195,40 +237,69 @@ def clear_response(history):
195
  history = history[:-index_conv]
196
  return history, question
197
 
 
 
 
 
 
 
 
198
 
199
- # def print_like_dislike(x: gr.LikeData):
200
- # print(x.index, x.value, x.liked)
201
 
202
 
203
  def add_message(history, message):
204
- # history=[]
205
- global our_chatbot
206
- if len(history) == 0:
207
- our_chatbot = InferenceDemo(
208
- args, model_path, tokenizer, model, image_processor, context_len
209
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
- for x in message["files"]:
212
- history.append(((x,), None))
213
- if message["text"] is not None:
214
- history.append((message["text"], None))
215
- return history, gr.MultimodalTextbox(value=None, interactive=False)
216
 
217
 
218
  @spaces.GPU
219
  def bot(history, temperature, top_p, max_output_tokens):
220
- print("### turn start history",history)
221
- print("### turn start conv",our_chatbot.conversation)
222
  text = history[-1][0]
223
  images_this_term = []
224
  text_this_term = ""
225
- # import pdb;pdb.set_trace()
226
  num_new_images = 0
 
227
  for i, message in enumerate(history[:-1]):
228
  if type(message[0]) is tuple:
 
 
 
 
 
229
  images_this_term.append(message[0][0])
230
  if is_valid_video_filename(message[0][0]):
231
- # 不接受视频
232
  raise ValueError("Video is not supported")
233
  num_new_images += our_chatbot.num_frames
234
  elif is_valid_image_filename(message[0][0]):
@@ -236,15 +307,10 @@ def bot(history, temperature, top_p, max_output_tokens):
236
  num_new_images += 1
237
  else:
238
  raise ValueError("Invalid image file")
 
239
  else:
240
  num_new_images = 0
241
-
242
- # for message in history[-i-1:]:
243
- # images_this_term.append(message[0][0])
244
-
245
- assert len(images_this_term) > 0, "must have an image"
246
- # image_files = (args.image_file).split(',')
247
- # image = [load_image(f) for f in images_this_term if f]
248
 
249
  all_image_hash = []
250
  all_image_path = []
@@ -288,9 +354,7 @@ def bot(history, temperature, top_p, max_output_tokens):
288
 
289
  image_tensor = torch.stack(image_tensor)
290
  image_token = DEFAULT_IMAGE_TOKEN * num_new_images
291
- # if our_chatbot.model.config.mm_use_im_start_end:
292
- # inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
293
- # else:
294
  inp = text
295
  inp = image_token + "\n" + inp
296
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
@@ -298,13 +362,6 @@ def bot(history, temperature, top_p, max_output_tokens):
298
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
299
  prompt = our_chatbot.conversation.get_prompt()
300
 
301
- # input_ids = (
302
- # tokenizer_image_token(
303
- # prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
304
- # )
305
- # .unsqueeze(0)
306
- # .to(our_chatbot.model.device)
307
- # )
308
  input_ids = tokenizer_image_token(
309
  prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
310
  ).unsqueeze(0).to(our_chatbot.model.device)
@@ -318,9 +375,7 @@ def bot(history, temperature, top_p, max_output_tokens):
318
  stopping_criteria = KeywordsStoppingCriteria(
319
  keywords, our_chatbot.tokenizer, input_ids
320
  )
321
- # streamer = TextStreamer(
322
- # our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
323
- # )
324
  streamer = TextIteratorStreamer(
325
  our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
326
  )
@@ -328,27 +383,6 @@ def bot(history, temperature, top_p, max_output_tokens):
328
  print(input_ids.device)
329
  print(image_tensor.device)
330
 
331
- # with torch.inference_mode():
332
- # output_ids = our_chatbot.model.generate(
333
- # input_ids,
334
- # images=image_tensor,
335
- # do_sample=True,
336
- # temperature=0.7,
337
- # top_p=1.0,
338
- # max_new_tokens=4096,
339
- # streamer=streamer,
340
- # use_cache=False,
341
- # stopping_criteria=[stopping_criteria],
342
- # )
343
-
344
- # outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
345
- # if outputs.endswith(stop_str):
346
- # outputs = outputs[: -len(stop_str)]
347
- # our_chatbot.conversation.messages[-1][-1] = outputs
348
-
349
- # history[-1] = [text, outputs]
350
-
351
- # return history
352
  generate_kwargs = dict(
353
  inputs=input_ids,
354
  streamer=streamer,
@@ -367,13 +401,12 @@ def bot(history, temperature, top_p, max_output_tokens):
367
  outputs = []
368
  for stream_token in streamer:
369
  outputs.append(stream_token)
370
- # print("### stream_token",stream_token)
371
- # our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
372
  history[-1] = [text, "".join(outputs)]
373
  yield history
374
  our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
375
- print("### turn end history", history)
376
- print("### turn end conv",our_chatbot.conversation)
377
 
378
  with open(get_conv_log_filename(), "a") as fout:
379
  data = {
@@ -637,17 +670,25 @@ with gr.Blocks(
637
  gr.Markdown(learn_more_markdown)
638
  gr.Markdown(bibtext)
639
 
640
- chat_msg = chat_input.submit(
641
- add_message, [chatbot, chat_input], [chatbot, chat_input]
642
- )
643
- bot_msg = chat_msg.then(bot, [chatbot,temperature, top_p, max_output_tokens], chatbot, api_name="bot_response")
644
- bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
645
 
646
  # chatbot.like(print_like_dislike, None, None)
647
  clear_btn.click(
648
  fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
649
  )
650
 
 
 
 
 
 
 
 
 
 
651
 
652
  demo.queue()
653
 
@@ -678,5 +719,5 @@ if __name__ == "__main__":
678
  model_name = get_model_name_from_path(args.model_path)
679
  tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
680
  model=model.to(torch.device('cuda'))
681
- our_chatbot = None
682
  demo.launch()
 
62
 
63
  external_log_dir = "./logs"
64
  LOGDIR = external_log_dir
65
+ VOTEDIR = "./votes"
66
 
67
 
68
  def install_gradio_4_35_0():
 
88
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
89
  return name
90
 
91
+ def get_conv_vote_filename():
92
+ t = datetime.datetime.now()
93
+ name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
94
+ if not os.path.isfile(name):
95
+ os.makedirs(os.path.dirname(name), exist_ok=True)
96
+ return name
97
+
98
+ def vote_last_response(state, vote_type, model_selector):
99
+ with open(get_conv_vote_filename(), "a") as fout:
100
+ data = {
101
+ "type": vote_type,
102
+ "model": model_selector,
103
+ "state": state,
104
+ }
105
+ fout.write(json.dumps(data) + "\n")
106
+ api.upload_file(
107
+ path_or_fileobj=get_conv_vote_filename(),
108
+ path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
109
+ repo_id=repo_name,
110
+ repo_type="dataset")
111
+
112
+
113
+ def upvote_last_response(state):
114
+ vote_last_response(state, "upvote", "Pangea-7b")
115
+ gr.Info("Thank you for your voting!")
116
+ return state
117
+
118
+ def downvote_last_response(state):
119
+ vote_last_response(state, "downvote", "Pangea-7b")
120
+ gr.Info("Thank you for your voting!")
121
+ return state
122
+
123
  class InferenceDemo(object):
124
  def __init__(
125
  self, args, model_path, tokenizer, model, image_processor, context_len
 
158
  self.conversation = conv_templates[args.conv_mode].copy()
159
  self.num_frames = args.num_frames
160
 
161
+ class ChatSessionManager:
162
+ def __init__(self):
163
+ self.chatbot_instance = None
164
+
165
+ def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
166
+ self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
167
+ print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
168
+
169
+ def reset_chatbot(self):
170
+ self.chatbot_instance = None
171
+
172
+ def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
173
+ if self.chatbot_instance is None:
174
+ self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
175
+ return self.chatbot_instance
176
+
177
 
178
  def is_valid_video_filename(name):
179
  video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
 
227
  return image
228
 
229
 
 
 
 
 
 
 
 
230
  def clear_response(history):
231
  for index_conv in range(1, len(history)):
232
  # loop until get a text response from our model.
 
237
  history = history[:-index_conv]
238
  return history, question
239
 
240
+ chat_manager = ChatSessionManager()
241
+
242
+
243
+ def clear_history(history):
244
+ chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
245
+ chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
246
+ return None
247
 
 
 
248
 
249
 
250
  def add_message(history, message):
251
+ global chat_image_num
252
+ if not history:
253
+ history = []
254
+ our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
255
+ chat_image_num = 0
256
+
257
+ if len(message["files"]) <= 1:
258
+ for x in message["files"]:
259
+ history.append(((x,), None))
260
+ chat_image_num += 1
261
+ if chat_image_num > 1:
262
+ history = []
263
+ chat_manager.reset_chatbot()
264
+ our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
265
+ chat_image_num = 0
266
+ for x in message["files"]:
267
+ history.append(((x,), None))
268
+ chat_image_num += 1
269
+
270
+ if message["text"] is not None:
271
+ history.append((message["text"], None))
272
+
273
+ print(f"### Chatbot instance ID: {id(our_chatbot)}")
274
+ return history, gr.MultimodalTextbox(value=None, interactive=False)
275
+ else:
276
+ for x in message["files"]:
277
+ history.append(((x,), None))
278
+ if message["text"] is not None:
279
+ history.append((message["text"], None))
280
 
281
+ return history, gr.MultimodalTextbox(value=None, interactive=False)
 
 
 
 
282
 
283
 
284
  @spaces.GPU
285
  def bot(history, temperature, top_p, max_output_tokens):
286
+ our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
287
+ print(f"### Chatbot instance ID: {id(our_chatbot)}")
288
  text = history[-1][0]
289
  images_this_term = []
290
  text_this_term = ""
291
+
292
  num_new_images = 0
293
+ previous_image = False
294
  for i, message in enumerate(history[:-1]):
295
  if type(message[0]) is tuple:
296
+ if previous_image:
297
+ gr.Warning("Only one image can be uploaded in a conversation. Please reduce the number of images and start a new conversation.")
298
+ our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
299
+ return None
300
+
301
  images_this_term.append(message[0][0])
302
  if is_valid_video_filename(message[0][0]):
 
303
  raise ValueError("Video is not supported")
304
  num_new_images += our_chatbot.num_frames
305
  elif is_valid_image_filename(message[0][0]):
 
307
  num_new_images += 1
308
  else:
309
  raise ValueError("Invalid image file")
310
+ previous_image = True
311
  else:
312
  num_new_images = 0
313
+ previous_image = False
 
 
 
 
 
 
314
 
315
  all_image_hash = []
316
  all_image_path = []
 
354
 
355
  image_tensor = torch.stack(image_tensor)
356
  image_token = DEFAULT_IMAGE_TOKEN * num_new_images
357
+
 
 
358
  inp = text
359
  inp = image_token + "\n" + inp
360
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
 
362
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
363
  prompt = our_chatbot.conversation.get_prompt()
364
 
 
 
 
 
 
 
 
365
  input_ids = tokenizer_image_token(
366
  prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
367
  ).unsqueeze(0).to(our_chatbot.model.device)
 
375
  stopping_criteria = KeywordsStoppingCriteria(
376
  keywords, our_chatbot.tokenizer, input_ids
377
  )
378
+
 
 
379
  streamer = TextIteratorStreamer(
380
  our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
381
  )
 
383
  print(input_ids.device)
384
  print(image_tensor.device)
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  generate_kwargs = dict(
387
  inputs=input_ids,
388
  streamer=streamer,
 
401
  outputs = []
402
  for stream_token in streamer:
403
  outputs.append(stream_token)
404
+
 
405
  history[-1] = [text, "".join(outputs)]
406
  yield history
407
  our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
408
+ # print("### turn end history", history)
409
+ # print("### turn end conv",our_chatbot.conversation)
410
 
411
  with open(get_conv_log_filename(), "a") as fout:
412
  data = {
 
670
  gr.Markdown(learn_more_markdown)
671
  gr.Markdown(bibtext)
672
 
673
+ chat_input.submit(
674
+ add_message, [chatbot, chat_input], [chatbot, chat_input]
675
+ ).then(bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response").then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
676
+
 
677
 
678
  # chatbot.like(print_like_dislike, None, None)
679
  clear_btn.click(
680
  fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
681
  )
682
 
683
+ upvote_btn.click(
684
+ fn=upvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
685
+ )
686
+
687
+
688
+ downvote_btn.click(
689
+ fn=downvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
690
+ )
691
+
692
 
693
  demo.queue()
694
 
 
719
  model_name = get_model_name_from_path(args.model_path)
720
  tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
721
  model=model.to(torch.device('cuda'))
722
+ chat_image_num = 0
723
  demo.launch()