KingNish commited on
Commit
8fbbb6f
·
verified ·
1 Parent(s): b8389ea

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +18 -19
chatbot.py CHANGED
@@ -191,7 +191,22 @@ def qwen_inference(user_prompt, chat_history):
191
  ]
192
  })
193
 
194
- return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  image_extensions = Image.registered_extensions()
197
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
@@ -204,24 +219,8 @@ client_mistral_nemo = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
204
 
205
  def model_inference(user_prompt, chat_history):
206
  if user_prompt["files"]:
207
- messages = qwen_inference(user_prompt, chat_history)
208
- text = processor.apply_chat_template(
209
- messages, tokenize=False, add_generation_prompt=True
210
- )
211
- image_inputs, video_inputs = process_vision_info(messages)
212
- inputs = processor(
213
- text=[text],
214
- images=image_inputs,
215
- videos=video_inputs,
216
- padding=True,
217
- return_tensors="pt",
218
- ).to("cuda")
219
-
220
- streamer = TextIteratorStreamer(
221
- processor, skip_prompt=True, **{"skip_special_tokens": True}
222
- )
223
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
224
-
225
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
226
  thread.start()
227
 
 
191
  ]
192
  })
193
 
194
+ text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True)
195
+ image_inputs, video_inputs = process_vision_info(messages)
196
+ inputs = processor(
197
+ text=[text],
198
+ images=image_inputs,
199
+ videos=video_inputs,
200
+ padding=True,
201
+ return_tensors="pt",
202
+ ).to("cuda")
203
+
204
+ streamer = TextIteratorStreamer(
205
+ processor, skip_prompt=True, **{"skip_special_tokens": True}
206
+ )
207
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
208
+
209
+ return streamer, generation_kwargs
210
 
211
  image_extensions = Image.registered_extensions()
212
  video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
 
219
 
220
  def model_inference(user_prompt, chat_history):
221
  if user_prompt["files"]:
222
+ streamer, generation_kwargs = qwen_inference(user_prompt, chat_history)
223
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
225
  thread.start()
226