Spaces:
Running
Running
Test v2
Browse files
main.py
CHANGED
@@ -226,7 +226,6 @@ async def stream_generator(response, prompt, info_memoire):
|
|
226 |
raise HTTPException(status_code=504, detail="Stream timed out")
|
227 |
|
228 |
|
229 |
-
|
230 |
@app.post("/generate-answer/")
|
231 |
def generate_answer(user_input: UserInput):
|
232 |
try:
|
@@ -289,6 +288,80 @@ def generate_answer(user_input: UserInput):
|
|
289 |
|
290 |
except Exception as e:
|
291 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
|
293 |
@app.get("/models")
|
294 |
def get_models():
|
|
|
226 |
raise HTTPException(status_code=504, detail="Stream timed out")
|
227 |
|
228 |
|
|
|
229 |
@app.post("/generate-answer/")
|
230 |
def generate_answer(user_input: UserInput):
|
231 |
try:
|
|
|
288 |
|
289 |
except Exception as e:
|
290 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
291 |
+
|
292 |
+
|
293 |
+
|
294 |
+
async def stream_generator2(response, prompt, info_memoire):
|
295 |
+
async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
|
296 |
+
try:
|
297 |
+
async for chunk in response:
|
298 |
+
if isinstance(chunk, bytes):
|
299 |
+
yield chunk.decode('utf-8') # Convert bytes to str if needed
|
300 |
+
except asyncio.TimeoutError:
|
301 |
+
raise HTTPException(status_code=504, detail="Stream timed out")
|
302 |
+
|
303 |
+
@app.post("/v2/generate-answer/")
|
304 |
+
def generate_answer(user_input: UserInput):
|
305 |
+
try:
|
306 |
+
print(user_input)
|
307 |
+
|
308 |
+
prompt = user_input.prompt
|
309 |
+
enterprise_id = user_input.enterprise_id
|
310 |
+
|
311 |
+
template_prompt = base_template
|
312 |
+
|
313 |
+
context = get_retreive_answer(enterprise_id, prompt, index, common_namespace, user_id=user_input.user_id)
|
314 |
+
|
315 |
+
#final_prompt_simplified = prompt_formatting(prompt,template,context)
|
316 |
+
infos_added_to_kb = handle_calling_add_to_knowledge_base(prompt,enterprise_id,index,getattr(user_input,"marque",""),user_id=getattr(user_input,"user_id",""))
|
317 |
+
if infos_added_to_kb:
|
318 |
+
prompt = prompt + "l'information a été ajoutée à la base de connaissance: " + infos_added_to_kb['item']
|
319 |
+
else:
|
320 |
+
infos_added_to_kb = {}
|
321 |
+
|
322 |
+
if not context:
|
323 |
+
context = ""
|
324 |
+
|
325 |
+
if user_input.style_tonality is None:
|
326 |
+
prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
|
327 |
+
answer = generate_response_via_langchain(prompt,
|
328 |
+
model=getattr(user_input,"model","gpt-4o"),
|
329 |
+
stream=user_input.stream,context = context ,
|
330 |
+
messages=user_input.messages,
|
331 |
+
template=template_prompt,
|
332 |
+
enterprise_name=getattr(user_input,"marque",""),
|
333 |
+
enterprise_id=enterprise_id,
|
334 |
+
index=index)
|
335 |
+
else:
|
336 |
+
prompt_formated = prompt_reformatting(template_prompt,
|
337 |
+
context,
|
338 |
+
prompt,
|
339 |
+
style=getattr(user_input.style_tonality,"style","neutral"),
|
340 |
+
tonality=getattr(user_input.style_tonality,"tonality","formal"),
|
341 |
+
enterprise_name=getattr(user_input,"marque",""))
|
342 |
+
|
343 |
+
answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),
|
344 |
+
stream=user_input.stream,context = context ,
|
345 |
+
messages=user_input.messages,
|
346 |
+
style=getattr(user_input.style_tonality,"style","neutral"),
|
347 |
+
tonality=getattr(user_input.style_tonality,"tonality","formal"),
|
348 |
+
template=template_prompt,
|
349 |
+
enterprise_name=getattr(user_input,"marque",""),
|
350 |
+
enterprise_id=enterprise_id,
|
351 |
+
index=index)
|
352 |
+
|
353 |
+
if user_input.stream:
|
354 |
+
return StreamingResponse(stream_generator2(answer,prompt_formated,infos_added_to_kb), media_type="application/json")
|
355 |
+
|
356 |
+
return {
|
357 |
+
"prompt": prompt_formated,
|
358 |
+
"answer": answer,
|
359 |
+
"context": context,
|
360 |
+
"info_memoire": infos_added_to_kb
|
361 |
+
}
|
362 |
+
|
363 |
+
except Exception as e:
|
364 |
+
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
365 |
|
366 |
@app.get("/models")
|
367 |
def get_models():
|