Ilyas KHIAT commited on
Commit
79e7580
·
1 Parent(s): fde2e8b

style and ton optionnal

Browse files
Files changed (1) hide show
  1. main.py +16 -4
main.py CHANGED
@@ -69,8 +69,8 @@ app = FastAPI(dependencies=[Depends(api_key_auth)])
69
  # return response
70
 
71
  class StyleWriter(BaseModel):
72
- style: str
73
- tonality: str
74
 
75
  models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
76
 
@@ -212,8 +212,20 @@ def generate_answer(user_input: UserInput):
212
  prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
213
  answer = generate_response_via_langchain(prompt, model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
214
  else:
215
- prompt_formated = prompt_reformatting(template_prompt,context,prompt,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,enterprise_name=getattr(user_input,"marque",""))
216
- answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  if user_input.stream:
219
  return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")
 
69
  # return response
70
 
71
  class StyleWriter(BaseModel):
72
+ style: Optional[str] = "neutral"
73
+ tonality: Optional[str] = "formal"
74
 
75
  models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]
76
 
 
212
  prompt_formated = prompt_reformatting(template_prompt,context,prompt,enterprise_name=getattr(user_input,"marque",""))
213
  answer = generate_response_via_langchain(prompt, model=getattr(user_input,"model","gpt-4o"),stream=user_input.stream,context = context , messages=user_input.messages,template=template_prompt,enterprise_name=getattr(user_input,"marque",""))
214
  else:
215
+ prompt_formated = prompt_reformatting(template_prompt,
216
+ context,
217
+ prompt,
218
+ style=getattr(user_input.style_tonality,"style","neutral"),
219
+ tonality=getattr(user_input.style_tonality,"tonality","formal"),
220
+ enterprise_name=getattr(user_input,"marque",""))
221
+
222
+ answer = generate_response_via_langchain(prompt,model=getattr(user_input,"model","gpt-4o"),
223
+ stream=user_input.stream,context = context ,
224
+ messages=user_input.messages,
225
+ style=getattr(user_input.style_tonality,"style","neutral"),
226
+ tonality=getattr(user_input.style_tonality,"tonality","formal"),
227
+ template=template_prompt,
228
+ enterprise_name=getattr(user_input,"marque",""))
229
 
230
  if user_input.stream:
231
  return StreamingResponse(stream_generator(answer,prompt_formated), media_type="application/json")