j commited on
Commit
c1f878d
·
1 Parent(s): 4bbc7e3

Changed the chatbot message format to use strings instead of dictionaries

Browse files
Files changed (1) hide show
  1. demo/web_demo_audio.py +54 -5
demo/web_demo_audio.py CHANGED
@@ -135,20 +135,69 @@ def _launch_demo(args):
135
 
136
  task_history = gr.State([])
137
 
138
- # Update event handlers for new input components
139
  def process_input(text, audio, chatbot, history):
 
140
  content = []
 
 
141
  if audio is not None:
142
  content.append({'type': 'audio', 'audio_url': audio})
 
 
143
  if text:
144
  content.append({'type': 'text', 'text': text})
 
145
 
146
  history.append({"role": "user", "content": content})
147
- chatbot.append([
148
- {"text": text, "audio": audio},
149
- None
150
- ])
151
  return "", None, chatbot, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  submit_btn.click(
154
  fn=process_input,
 
135
 
136
  task_history = gr.State([])
137
 
 
138
  def process_input(text, audio, chatbot, history):
139
+ """Process input with correct message formatting for Chatbot."""
140
  content = []
141
+ message_text = []
142
+
143
  if audio is not None:
144
  content.append({'type': 'audio', 'audio_url': audio})
145
+ message_text.append(f"[Audio file uploaded]")
146
+
147
  if text:
148
  content.append({'type': 'text', 'text': text})
149
+ message_text.append(text)
150
 
151
  history.append({"role": "user", "content": content})
152
+ # Format message for chatbot as a string instead of dict
153
+ chatbot.append([" ".join(message_text), None])
 
 
154
  return "", None, chatbot, history
155
+
156
+ def predict(chatbot, task_history):
157
+ """Generate a response from the model."""
158
+ print(f"{task_history=}")
159
+ print(f"{chatbot=}")
160
+ text = processor.apply_chat_template(task_history, add_generation_prompt=True, tokenize=False)
161
+ audios = []
162
+ for message in task_history:
163
+ if isinstance(message["content"], list):
164
+ for ele in message["content"]:
165
+ if ele["type"] == "audio":
166
+ audios.append(
167
+ librosa.load(ele['audio_url'], sr=processor.feature_extractor.sampling_rate)[0]
168
+ )
169
+
170
+ if len(audios)==0:
171
+ audios=None
172
+ print(f"{text=}")
173
+ print(f"{audios=}")
174
+ inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
175
+ inputs["input_ids"] = inputs.input_ids.cuda()
176
+
177
+ generate_ids = model.generate(**inputs, max_length=256)
178
+ generate_ids = generate_ids[:, inputs.input_ids.size(1):]
179
+
180
+ response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
181
+ print(f"{response=}")
182
+ task_history.append({'role': 'assistant',
183
+ 'content': response})
184
+ chatbot.append((None, response))
185
+ return chatbot, task_history
186
+
187
+ # Update event handlers for new input components
188
+ # def process_input(text, audio, chatbot, history):
189
+ # content = []
190
+ # if audio is not None:
191
+ # content.append({'type': 'audio', 'audio_url': audio})
192
+ # if text:
193
+ # content.append({'type': 'text', 'text': text})
194
+ #
195
+ # history.append({"role": "user", "content": content})
196
+ # chatbot.append([
197
+ # {"text": text, "audio": audio},
198
+ # None
199
+ # ])
200
+ # return "", None, chatbot, history
201
 
202
  submit_btn.click(
203
  fn=process_input,