ClownRat commited on
Commit
ee3d0a5
β€’
1 Parent(s): fe61520

Update demo.

Browse files
Files changed (1) hide show
  1. app.py +26 -24
app.py CHANGED
@@ -102,6 +102,7 @@ class Chat:
102
  # 2. text preprocess (tag process & generate prompt).
103
  state = self.get_prompt(prompt, state)
104
  prompt = state.get_prompt()
 
105
  input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt')
106
  input_ids = input_ids.unsqueeze(0).to(self.model.device)
107
 
@@ -130,15 +131,13 @@ class Chat:
130
 
131
 
132
  @spaces.GPU(duration=120)
133
- def generate(image, video, first_run, state, state_, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
134
- flag = 1
135
  if not textbox_in:
136
  if len(state_.messages) > 0:
137
  textbox_in = state_.messages[-1][1]
138
  state_.messages.pop(-1)
139
- flag = 0
140
  else:
141
- return "Please enter instruction"
142
 
143
  image = image if image else "none"
144
  video = video if video else "none"
@@ -187,30 +186,34 @@ def generate(image, video, first_run, state, state_, textbox_in, temperature, to
187
  if os.path.exists(video):
188
  show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
189
 
190
- if flag:
191
- state.append_message(state.roles[0], textbox_in + "\n" + show_images)
192
  state.append_message(state.roles[1], textbox_out)
193
 
194
- return (gr.update(value=image if os.path.exists(image) else None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True),
195
- state.to_gradio_chatbot(), False, state, state_, gr.update(value=None, interactive=True))
 
 
 
 
 
196
 
197
 
198
- def regenerate(state, state_, textbox):
 
199
  state.messages.pop(-1)
200
- state_.messages.pop(-1)
201
- textbox = gr.update(value=None, interactive=True)
202
  if len(state.messages) > 0:
203
- return state, state_, textbox, state.to_gradio_chatbot(), False
204
- return state, state_, textbox, state.to_gradio_chatbot(), True
205
 
206
 
207
  def clear_history(state, state_):
208
  state = conv_templates[conv_mode].copy()
209
  state_ = conv_templates[conv_mode].copy()
210
  return (gr.update(value=None, interactive=True),
211
- gr.update(value=None, interactive=True), \
212
- state.to_gradio_chatbot(), \
213
- True, state, state_, gr.update(value=None, interactive=True))
 
214
 
215
  # BUG of Zero Environment
216
  # 1. The environment is fixed to torch==2.0.1+cu117, gradio>=4.x.x
@@ -230,7 +233,6 @@ with gr.Blocks(title='VideoLLaMA 2 πŸ”₯πŸš€πŸ”₯', theme=gr.themes.Default(primar
230
  gr.Markdown(title_markdown)
231
  state = gr.State()
232
  state_ = gr.State()
233
- first_run = gr.State()
234
 
235
  with gr.Row():
236
  with gr.Column(scale=3):
@@ -331,20 +333,20 @@ with gr.Blocks(title='VideoLLaMA 2 πŸ”₯πŸš€πŸ”₯', theme=gr.themes.Default(primar
331
 
332
  submit_btn.click(
333
  generate,
334
- [image, video, first_run, state, state_, textbox, temperature, top_p, max_output_tokens],
335
- [image, video, chatbot, first_run, state, state_, textbox])
336
 
337
  regenerate_btn.click(
338
  regenerate,
339
- [state, state_, textbox],
340
- [state, state_, textbox, chatbot, first_run]).then(
341
  generate,
342
- [image, video, first_run, state, state_, textbox, temperature, top_p, max_output_tokens],
343
- [image, video, chatbot, first_run, state, state_, textbox])
344
 
345
  clear_btn.click(
346
  clear_history,
347
  [state, state_],
348
- [image, video, chatbot, first_run, state, state_, textbox])
349
 
350
  demo.launch()
 
102
  # 2. text preprocess (tag process & generate prompt).
103
  state = self.get_prompt(prompt, state)
104
  prompt = state.get_prompt()
105
+
106
  input_ids = tokenizer_MMODAL_token(prompt, tokenizer, MMODAL_TOKEN_INDEX[modals[0]], return_tensors='pt')
107
  input_ids = input_ids.unsqueeze(0).to(self.model.device)
108
 
 
131
 
132
 
133
  @spaces.GPU(duration=120)
134
+ def generate(image, video, state, state_, textbox_in, temperature, top_p, max_output_tokens, dtype=torch.float16):
 
135
  if not textbox_in:
136
  if len(state_.messages) > 0:
137
  textbox_in = state_.messages[-1][1]
138
  state_.messages.pop(-1)
 
139
  else:
140
+ assert "Please enter instruction"
141
 
142
  image = image if image else "none"
143
  video = video if video else "none"
 
186
  if os.path.exists(video):
187
  show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={video}"></video>'
188
 
189
+ state.append_message(state.roles[0], textbox_in + "\n" + show_images)
 
190
  state.append_message(state.roles[1], textbox_out)
191
 
192
+ # BUG: only support single turn conversation now.
193
+ state_.messages.pop(-1)
194
+ state_.messages.pop(-1)
195
+
196
+ return (gr.update(value=image if os.path.exists(image) else None, interactive=True),
197
+ gr.update(value=video if os.path.exists(video) else None, interactive=True),
198
+ state.to_gradio_chatbot(), state, state_)
199
 
200
 
201
+ def regenerate(state, state_):
202
+ state.messages.pop(-1)
203
  state.messages.pop(-1)
 
 
204
  if len(state.messages) > 0:
205
+ return state.to_gradio_chatbot(), state, state_
206
+ return state.to_gradio_chatbot(), state, state_
207
 
208
 
209
  def clear_history(state, state_):
210
  state = conv_templates[conv_mode].copy()
211
  state_ = conv_templates[conv_mode].copy()
212
  return (gr.update(value=None, interactive=True),
213
+ gr.update(value=None, interactive=True),
214
+ state.to_gradio_chatbot(), state, state_,
215
+ gr.update(value=None, interactive=True))
216
+
217
 
218
  # BUG of Zero Environment
219
  # 1. The environment is fixed to torch==2.0.1+cu117, gradio>=4.x.x
 
233
  gr.Markdown(title_markdown)
234
  state = gr.State()
235
  state_ = gr.State()
 
236
 
237
  with gr.Row():
238
  with gr.Column(scale=3):
 
333
 
334
  submit_btn.click(
335
  generate,
336
+ [image, video, state, state_, textbox, temperature, top_p, max_output_tokens],
337
+ [image, video, chatbot, state, state_])
338
 
339
  regenerate_btn.click(
340
  regenerate,
341
+ [state, state_],
342
+ [chatbot, state, state_]).then(
343
  generate,
344
+ [image, video, state, state_, textbox, temperature, top_p, max_output_tokens],
345
+ [image, video, chatbot, state, state_])
346
 
347
  clear_btn.click(
348
  clear_history,
349
  [state, state_],
350
+ [image, video, chatbot, state, state_, textbox])
351
 
352
  demo.launch()