pvanand committed
Commit bc15143
Parent: d9309f4

add streaming response

Files changed (1)
  1. main.py +32 -28
main.py CHANGED
@@ -1,7 +1,8 @@
-from fastapi import FastAPI, HTTPException, Query, Path, Header, Depends
+from fastapi import FastAPI, HTTPException, Header, Depends, BackgroundTasks, Query
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, AsyncGenerator
 import json
 import os
 import logging
@@ -10,7 +11,7 @@ import pandas as pd
 import glob
 import uuid
 import httpx
-
+import asyncio
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -157,26 +158,27 @@ async def get_api_key(x_api_key: str = Header(...)) -> str:
         raise HTTPException(status_code=403, detail="Invalid API key")
     return x_api_key
 
-async def make_llm_request(api_key: str, llm_request: Dict[str, str]) -> Dict:
+async def stream_llm_request(api_key: str, llm_request: Dict[str, str]) -> AsyncGenerator[str, None]:
     """
-    Make a request to the LLM service.
+    Make a streaming request to the LLM service.
     """
     try:
         async with httpx.AsyncClient() as client:
-            response = await client.post(
+            async with client.stream(
+                "POST",
                 "https://pvanand-audio-chat.hf.space/llm-agent",
                 headers={
-                    "accept": "application/json",
+                    "accept": "text/event-stream",
                     "X-API-Key": api_key,
                     "Content-Type": "application/json"
                 },
                 json=llm_request
-            )
-
-            if response.status_code != 200:
-                raise HTTPException(status_code=response.status_code, detail="Error from LLM service")
-
-            return response.json()
+            ) as response:
+                if response.status_code != 200:
+                    raise HTTPException(status_code=response.status_code, detail="Error from LLM service")
+
+                async for chunk in response.aiter_text():
+                    yield chunk
     except httpx.HTTPError as e:
         logger.error(f"HTTP error occurred while making LLM request: {str(e)}")
         raise HTTPException(status_code=500, detail=f"HTTP error occurred while making LLM request: {str(e)}")
@@ -184,9 +186,8 @@ async def make_llm_request(api_key: str, llm_request: Dict[str, str]) -> Dict:
         logger.error(f"Unexpected error occurred while making LLM request: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Unexpected error occurred while making LLM request: {str(e)}")
 
-
-@app.post("/rag-chat/", response_model=dict, tags=["Chat"])
-async def chat(request: ChatRequest, api_key: str = Depends(get_api_key)):
+@app.post("/chat/", response_class=StreamingResponse, tags=["Chat"])
+async def chat(request: ChatRequest, background_tasks: BackgroundTasks, api_key: str = Depends(get_api_key)):
     """
     Chat endpoint that uses embeddings search and LLM for response generation.
     """
@@ -199,29 +200,32 @@ async def chat(request: ChatRequest, api_key: str = Depends(get_api_key)):
         context = "\n".join([document_list[idx[0]] for idx in search_results])
 
         # Create RAG prompt
-        rag_prompt = f"please answer the user's question:\n\nUser's question:{request.query} Based on the following context, \n\nContext:\n{context} \n\nAnswer:"
+        rag_prompt = f"Based on the following context, please answer the user's question:\n\nContext:\n{context}\n\nUser's question: {request.query}\n\nAnswer:"
 
-        rag_system_prompt = "You are a helpful assistant tasked with providing answers from the given context"
         # Generate conversation_id if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
 
         # Prepare the request for the LLM service
         llm_request = {
-            "prompt": rag_prompt,
-            "system_message": rag_system_prompt,
+            "prompt": request.query,
+            "system_message": rag_prompt,
             "model_id": request.model_id,
             "conversation_id": conversation_id,
             "user_id": request.user_id
         }
 
-        # Make request to LLM service
-        llm_response = await make_llm_request(api_key, llm_request)
-
-        logger.info(f"Chat response generated successfully for user: {request.user_id}")
-        return {
-            "response": llm_response,
-            "conversation_id": conversation_id
-        }
+        async def response_generator():
+            full_response = ""
+            async for chunk in stream_llm_request(api_key, llm_request):
+                full_response += chunk
+                yield chunk
+
+            # Here you might want to add logic to save the conversation or perform other background tasks
+            # For example:
+            # background_tasks.add_task(save_conversation, request.user_id, conversation_id, request.query, full_response)
+
+        logger.info(f"Starting chat response generation for user: {request.user_id}")
+        return StreamingResponse(response_generator(), media_type="text/event-stream")
 
     except Exception as e:
         logger.error(f"Error in chat endpoint: {str(e)}")
 
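The commented-out background_tasks.add_task(save_conversation, ...) hook leaves persistence unimplemented. One possible shape for that helper, purely illustrative (the function name comes from the comment in the diff, but the signature, file layout, and storage format here are assumptions):

import json
import os
from datetime import datetime, timezone

def save_conversation(user_id: str, conversation_id: str, query: str, response: str) -> None:
    """Hypothetical sink for completed exchanges: appends one JSON line per turn."""
    record = {
        "user_id": user_id,
        "conversation_id": conversation_id,
        "query": query,
        "response": response,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    os.makedirs("conversations", exist_ok=True)  # assumed local storage directory
    path = os.path.join("conversations", f"{conversation_id}.jsonl")
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

FastAPI's BackgroundTasks runs registered callables after the response finishes, so a synchronous helper like this would not block the stream.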