jonas commited on
Commit
28d4a09
·
verified ·
1 Parent(s): 3da458d

Upload app.py

Browse files

# Implement Real-Time Streaming for Chat Responses

## Description
This PR introduces real-time streaming functionality to our chat interface., aiming to enhance the user experience by providing immediate, token-by-token responses.

## Changes
- Enabled streaming in the HuggingFaceEndpoint configuration
- Implemented an asynchronous streaming process using `astream()`
- Modified the chat function to yield partial results in real-time
- Updated Gradio setup to support streaming responses (set queue as False)

## Expected Behavior
- Responses should start appearing immediately after a question is asked
- Text should stream in smoothly, word by word or token by token
- The final response should be identical to the non-streaming version

## Technical Details
Key components of the implementation:
1. **Streaming Callback**: Implemented `StreamingStdOutCallbackHandler` for real-time token processing.
2. **LLM Configuration**: Added `streaming=True` to `HuggingFaceEndpoint` setup.
3. **Asynchronous Streaming**: Created `process_stream()` function to handle token-by-token response generation.
4. **Real-Time Updates**: Modified main loop to yield updates as they become available.

Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -217,29 +217,37 @@ async def chat(query,history,sources,reports,subtype,year):
217
 
218
  ##-----------------------getting inference endpoints------------------------------
219
 
 
220
  callback = StreamingStdOutCallbackHandler()
221
 
 
222
  llm_qa = HuggingFaceEndpoint(
223
  endpoint_url=model_config.get('reader', 'ENDPOINT'),
224
  max_new_tokens=512,
225
  repetition_penalty=1.03,
226
  timeout=70,
227
  huggingfacehub_api_token=HF_token,
228
- streaming=True,
229
- callbacks=[callback]
230
  )
231
 
 
232
  chat_model = ChatHuggingFace(llm=llm_qa)
233
 
 
234
  docs_html = []
235
  for i, d in enumerate(context_retrieved, 1):
236
  docs_html.append(make_html_source(d, i))
237
  docs_html = "".join(docs_html)
238
 
 
239
  answer_yet = ""
240
 
 
241
  async def process_stream():
242
- nonlocal answer_yet
 
 
243
  async for chunk in chat_model.astream(messages):
244
  token = chunk.content
245
  answer_yet += token
@@ -247,9 +255,10 @@ async def chat(query,history,sources,reports,subtype,year):
247
  history[-1] = (query, parsed_answer)
248
  yield [tuple(x) for x in history], docs_html
249
 
 
250
  async for update in process_stream():
251
  yield update
252
-
253
  # #callbacks = [StreamingStdOutCallbackHandler()]
254
  # llm_qa = HuggingFaceEndpoint(
255
  # endpoint_url= model_config.get('reader','ENDPOINT'),
@@ -508,11 +517,13 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-
508
  # https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
509
  (textbox
510
  .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
 
511
  .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
512
  .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
513
 
514
  (examples_hidden
515
  .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
 
516
  .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
517
  .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
518
  )
 
217
 
218
  ##-----------------------getting inference endpoints------------------------------
219
 
220
+ # Set up the streaming callback handler
221
  callback = StreamingStdOutCallbackHandler()
222
 
223
+ # Initialize the HuggingFaceEndpoint with streaming enabled
224
  llm_qa = HuggingFaceEndpoint(
225
  endpoint_url=model_config.get('reader', 'ENDPOINT'),
226
  max_new_tokens=512,
227
  repetition_penalty=1.03,
228
  timeout=70,
229
  huggingfacehub_api_token=HF_token,
230
+ streaming=True, # Enable streaming for real-time token generation
231
+ callbacks=[callback] # Add the streaming callback handler
232
  )
233
 
234
+ # Create a ChatHuggingFace instance with the streaming-enabled endpoint
235
  chat_model = ChatHuggingFace(llm=llm_qa)
236
 
237
+ # Prepare the HTML for displaying source documents
238
  docs_html = []
239
  for i, d in enumerate(context_retrieved, 1):
240
  docs_html.append(make_html_source(d, i))
241
  docs_html = "".join(docs_html)
242
 
243
+ # Initialize the variable to store the accumulated answer
244
  answer_yet = ""
245
 
246
+ # Define an asynchronous generator function to process the streaming response
247
  async def process_stream():
248
+ # Without nonlocal, Python would create a new local variable answer_yet inside process_stream(), instead of modifying the one from the outer scope.
249
+ nonlocal answer_yet # Use the outer scope's answer_yet variable
250
+ # Iterate over the streaming response chunks
251
  async for chunk in chat_model.astream(messages):
252
  token = chunk.content
253
  answer_yet += token
 
255
  history[-1] = (query, parsed_answer)
256
  yield [tuple(x) for x in history], docs_html
257
 
258
+ # Stream the response updates
259
  async for update in process_stream():
260
  yield update
261
+
262
  # #callbacks = [StreamingStdOutCallbackHandler()]
263
  # llm_qa = HuggingFaceEndpoint(
264
  # endpoint_url= model_config.get('reader','ENDPOINT'),
 
517
  # https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
518
  (textbox
519
  .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
520
+ # queue must be set as False (default) so the process is not waiting for another to be finished
521
  .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
522
  .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
523
 
524
  (examples_hidden
525
  .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
526
+ # queue must be set as False (default) so the process is not waiting for another to be finished
527
  .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
528
  .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
529
  )