cdupland committed · verified
Commit 0e043cb · 1 Parent(s): 1bd65db
Files changed (1)
  1. main.py +74 -1
main.py CHANGED
@@ -226,7 +226,6 @@ async def stream_generator(response, prompt, info_memoire):
             raise HTTPException(status_code=504, detail="Stream timed out")
 
 
-
 @app.post("/generate-answer/")
 def generate_answer(user_input: UserInput):
     try:
@@ -289,6 +288,80 @@ def generate_answer(user_input: UserInput):
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+
+
+async def stream_generator2(response, prompt, info_memoire):
+    async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
+        try:
+            async for chunk in response:
+                if isinstance(chunk, bytes):
+                    yield chunk.decode('utf-8')  # Convert bytes to str if needed
+        except asyncio.TimeoutError:
+            raise HTTPException(status_code=504, detail="Stream timed out")
+
+@app.post("/v2/generate-answer/")
+def generate_answer(user_input: UserInput):
+    try:
+        print(user_input)
+
+        prompt = user_input.prompt
+        enterprise_id = user_input.enterprise_id
+
+        template_prompt = base_template
+
+        context = get_retreive_answer(enterprise_id, prompt, index, common_namespace, user_id=user_input.user_id)
+
+        # final_prompt_simplified = prompt_formatting(prompt, template, context)
+        infos_added_to_kb = handle_calling_add_to_knowledge_base(prompt, enterprise_id, index, getattr(user_input, "marque", ""), user_id=getattr(user_input, "user_id", ""))
+        if infos_added_to_kb:
+            prompt = prompt + "l'information a été ajoutée à la base de connaissance: " + infos_added_to_kb['item']
+        else:
+            infos_added_to_kb = {}
+
+        if not context:
+            context = ""
+
+        if user_input.style_tonality is None:
+            prompt_formated = prompt_reformatting(template_prompt, context, prompt, enterprise_name=getattr(user_input, "marque", ""))
+            answer = generate_response_via_langchain(prompt,
+                                                     model=getattr(user_input, "model", "gpt-4o"),
+                                                     stream=user_input.stream,
+                                                     context=context,
+                                                     messages=user_input.messages,
+                                                     template=template_prompt,
+                                                     enterprise_name=getattr(user_input, "marque", ""),
+                                                     enterprise_id=enterprise_id,
+                                                     index=index)
+        else:
+            prompt_formated = prompt_reformatting(template_prompt,
+                                                  context,
+                                                  prompt,
+                                                  style=getattr(user_input.style_tonality, "style", "neutral"),
+                                                  tonality=getattr(user_input.style_tonality, "tonality", "formal"),
+                                                  enterprise_name=getattr(user_input, "marque", ""))
+
+            answer = generate_response_via_langchain(prompt,
+                                                     model=getattr(user_input, "model", "gpt-4o"),
+                                                     stream=user_input.stream,
+                                                     context=context,
+                                                     messages=user_input.messages,
+                                                     style=getattr(user_input.style_tonality, "style", "neutral"),
+                                                     tonality=getattr(user_input.style_tonality, "tonality", "formal"),
+                                                     template=template_prompt,
+                                                     enterprise_name=getattr(user_input, "marque", ""),
+                                                     enterprise_id=enterprise_id,
+                                                     index=index)
+
+        if user_input.stream:
+            return StreamingResponse(stream_generator2(answer, prompt_formated, infos_added_to_kb), media_type="application/json")
+
+        return {
+            "prompt": prompt_formated,
+            "answer": answer,
+            "context": context,
+            "info_memoire": infos_added_to_kb
+        }
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
 
 @app.get("/models")
 def get_models():
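
For reference, below is a minimal client sketch for exercising the new /v2/generate-answer/ endpoint in streaming mode. It assumes the service listens on http://localhost:8000 and that the request body mirrors the UserInput attributes read in this hunk (prompt, enterprise_id, user_id, marque, model, stream, messages, style_tonality); the actual Pydantic schema is defined elsewhere in main.py, so the field values here are illustrative assumptions, not the confirmed contract.

# Hypothetical client for /v2/generate-answer/; base URL and payload values are assumptions.
import requests

payload = {
    "prompt": "Quelles sont nos offres actuelles ?",
    "enterprise_id": "acme-123",   # assumed identifier format
    "user_id": "user-42",          # assumed identifier format
    "marque": "ACME",
    "model": "gpt-4o",
    "stream": True,
    "messages": [],                # prior conversation turns, if any
    "style_tonality": {"style": "neutral", "tonality": "formal"},
}

# stream=True lets requests consume the StreamingResponse chunk by chunk,
# matching the path taken when user_input.stream is true in the endpoint.
with requests.post("http://localhost:8000/v2/generate-answer/",
                   json=payload, stream=True, timeout=120) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)

With "stream": False, the same request would instead return the JSON object built at the end of the handler (prompt, answer, context, info_memoire) in a single response.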