pvanand committed
Commit d9309f4
1 Parent(s): bb33281

add llm endpoint

Files changed (1):
  1. main.py +92 -2
main.py CHANGED
@@ -1,13 +1,15 @@
-from fastapi import FastAPI, HTTPException, Query, Path
+from fastapi import FastAPI, HTTPException, Query, Path, Header, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
-from typing import List
+from typing import List, Optional, Dict
 import json
 import os
 import logging
 from txtai.embeddings import Embeddings
 import pandas as pd
 import glob
+import uuid
+import httpx
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -19,6 +21,8 @@ app = FastAPI(
     version="1.0.0"
 )
 
+CHAT_AUTH_KEY = os.environ.get("CHAT_AUTH_KEY", "default_secret_key")
+
 # Enable CORS
 app.add_middleware(
     CORSMiddleware,
@@ -138,6 +142,92 @@ def check_and_index_csv_files():
         else:
             logger.info(f"Index already exists for: {csv_file}")
 
+
+# ... [Previous code for DocumentRequest, QueryRequest, save_embeddings, load_embeddings, create_index, query_index, process_csv_file, check_and_index_csv_files remains the same]
+
+class ChatRequest(BaseModel):
+    query: str = Field(..., description="The user's query")
+    index_id: str = Field(..., description="Unique identifier for the index to query")
+    conversation_id: Optional[str] = Field(None, description="Unique identifier for the conversation")
+    model_id: str = Field(..., description="Identifier for the LLM model to use")
+    user_id: str = Field(..., description="Unique identifier for the user")
+
+async def get_api_key(x_api_key: str = Header(...)) -> str:
+    if x_api_key != CHAT_AUTH_KEY:
+        raise HTTPException(status_code=403, detail="Invalid API key")
+    return x_api_key
+
+async def make_llm_request(api_key: str, llm_request: Dict[str, str]) -> Dict:
+    """
+    Make a request to the LLM service.
+    """
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                "https://pvanand-audio-chat.hf.space/llm-agent",
+                headers={
+                    "accept": "application/json",
+                    "X-API-Key": api_key,
+                    "Content-Type": "application/json"
+                },
+                json=llm_request
+            )
+
+            if response.status_code != 200:
+                raise HTTPException(status_code=response.status_code, detail="Error from LLM service")
+
+            return response.json()
+    except httpx.HTTPError as e:
+        logger.error(f"HTTP error occurred while making LLM request: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"HTTP error occurred while making LLM request: {str(e)}")
+    except Exception as e:
+        logger.error(f"Unexpected error occurred while making LLM request: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Unexpected error occurred while making LLM request: {str(e)}")
+
+
+@app.post("/rag-chat/", response_model=dict, tags=["Chat"])
+async def chat(request: ChatRequest, api_key: str = Depends(get_api_key)):
+    """
+    Chat endpoint that uses embeddings search and an LLM for response generation.
+    """
+    try:
+        # Load the document list for the specified index
+        document_list = load_embeddings(request.index_id)
+
+        # Perform embeddings search
+        search_results = embeddings.search(request.query, 5)  # Get top 5 relevant results
+        context = "\n".join([document_list[idx[0]] for idx in search_results])
+
+        # Create RAG prompt
+        rag_prompt = f"Please answer the user's question based on the following context.\n\nContext:\n{context}\n\nUser's question: {request.query}\n\nAnswer:"
+
+        rag_system_prompt = "You are a helpful assistant tasked with providing answers from the given context."
+        # Generate a conversation_id if not provided
+        conversation_id = request.conversation_id or str(uuid.uuid4())
+
+        # Prepare the request for the LLM service
+        llm_request = {
+            "prompt": rag_prompt,
+            "system_message": rag_system_prompt,
+            "model_id": request.model_id,
+            "conversation_id": conversation_id,
+            "user_id": request.user_id
+        }
+
+        # Make the request to the LLM service
+        llm_response = await make_llm_request(api_key, llm_request)
+
+        logger.info(f"Chat response generated successfully for user: {request.user_id}")
+        return {
+            "response": llm_response,
+            "conversation_id": conversation_id
+        }
+
+    except Exception as e:
+        logger.error(f"Error in chat endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error in chat endpoint: {str(e)}")
+
+
 @app.on_event("startup")
 async def startup_event():
     check_and_index_csv_files()
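
For reference, a minimal way to exercise the new /rag-chat/ endpoint once the app is running. This is a sketch, not part of the commit: the base URL and port (7860, the usual Spaces default) and the index_id/model_id values are placeholders, and the X-API-Key must match whatever CHAT_AUTH_KEY is set to in the environment.

import asyncio
import httpx

async def main():
    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(
            "http://localhost:7860/rag-chat/",            # assumed host/port
            headers={"X-API-Key": "default_secret_key"},  # must equal CHAT_AUTH_KEY
            json={
                "query": "What does the dataset say about revenue?",
                "index_id": "example_index",  # placeholder index name
                "model_id": "example-model",  # placeholder model id
                "user_id": "user-123",
                # "conversation_id" omitted -> server generates a uuid4
            },
        )
    print(response.status_code)
    print(response.json())  # {"response": <LLM service reply>, "conversation_id": "..."}

asyncio.run(main())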
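
The retrieval step relies on a contract the diff only implies: the txtai index was built over an enumerated list of strings, so each search hit is an (id, score) tuple whose integer id indexes back into document_list. A standalone sketch of that contract (the model name is illustrative; any sentence-transformers model works):

from txtai.embeddings import Embeddings

docs = ["first row as text", "second row as text", "third row as text"]

embeddings = Embeddings({"path": "sentence-transformers/all-MiniLM-L6-v2"})
embeddings.index([(i, text, None) for i, text in enumerate(docs)])

# search() returns [(id, score), ...]; idx[0] is the integer id,
# which is why the endpoint can do document_list[idx[0]]
results = embeddings.search("second", 2)
context = "\n".join(docs[idx[0]] for idx in results)
print(results)
print(context)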
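
And a small check of the get_api_key guard using FastAPI's TestClient. This assumes main.py is importable and CHAT_AUTH_KEY is left at its insecure default; note that omitting the X-API-Key header entirely yields a 422 validation error from FastAPI rather than the 403 raised here.

from fastapi.testclient import TestClient
from main import app

client = TestClient(app)

payload = {
    "query": "hello",
    "index_id": "example_index",  # placeholder
    "model_id": "example-model",  # placeholder
    "user_id": "user-123",
}

# Wrong key -> rejected by the dependency before the handler body runs
r = client.post("/rag-chat/", json=payload, headers={"X-API-Key": "wrong-key"})
assert r.status_code == 403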