Ilyas KHIAT commited on
Commit
9a4c626
·
1 Parent(s): 04c9a9c

style and ton

Browse files
Files changed (2) hide show
  1. main.py +9 -2
  2. rag.py +4 -4
main.py CHANGED
@@ -35,17 +35,24 @@ index = pc.Index(index_name)
35
 
36
  app = FastAPI()
37
 
 
 
 
38
 
39
  class UserInput(BaseModel):
40
  prompt: str
41
  enterprise_id: str
42
  stream: Optional[bool] = False
43
  messages: Optional[list[dict]] = []
 
 
44
 
45
  class EnterpriseData(BaseModel):
46
  name: str
47
  id: Optional[str] = None
48
 
 
 
49
  tasks = []
50
 
51
  @app.get("/")
@@ -124,7 +131,7 @@ async def stream_generator(response):
124
  async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
125
  try:
126
  async for chunk in response:
127
- yield {"content":chunk}
128
  except asyncio.TimeoutError:
129
  raise HTTPException(status_code=504, detail="Stream timed out")
130
 
@@ -139,7 +146,7 @@ def generate_answer(user_input: UserInput):
139
  if not context:
140
  context = "No context found"
141
 
142
- answer = generate_response_via_langchain(prompt, model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages)
143
 
144
  if user_input.stream:
145
  return StreamingResponse(answer, media_type="application/json")
 
35
 
36
  app = FastAPI()
37
 
38
+ class StyleWriter(BaseModel):
39
+ style: str
40
+ tonality: str
41
 
42
  class UserInput(BaseModel):
43
  prompt: str
44
  enterprise_id: str
45
  stream: Optional[bool] = False
46
  messages: Optional[list[dict]] = []
47
+ style_tonality: Optional[StyleWriter] = None
48
+
49
 
50
  class EnterpriseData(BaseModel):
51
  name: str
52
  id: Optional[str] = None
53
 
54
+
55
+
56
  tasks = []
57
 
58
  @app.get("/")
 
131
  async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
132
  try:
133
  async for chunk in response:
134
+ yield "random data"
135
  except asyncio.TimeoutError:
136
  raise HTTPException(status_code=504, detail="Stream timed out")
137
 
 
146
  if not context:
147
  context = "No context found"
148
 
149
+ answer = generate_response_via_langchain(prompt, model="gpt-4o",stream=user_input.stream,context = context , messages=user_input.messages,style=user_input.style_tonality.style,tonality=user_input.style_tonality.tonality)
150
 
151
  if user_input.stream:
152
  return StreamingResponse(answer, media_type="application/json")
rag.py CHANGED
@@ -82,9 +82,9 @@ def get_retreive_answer(enterprise_id,prompt,index):
82
  print(e)
83
  return False
84
 
85
- def generate_response_via_langchain(query: str, stream: bool = False, model: str = "gpt-4o-mini",context:str="",messages = []) :
86
  # Define the prompt template
87
- template = "Sachant le context suivant: {context}, et l'historique de la conversation: {messages}, {query}"
88
  prompt = PromptTemplate.from_template(template)
89
 
90
  # Initialize the OpenAI LLM with the specified model
@@ -95,10 +95,10 @@ def generate_response_via_langchain(query: str, stream: bool = False, model: str
95
 
96
  if stream:
97
  # Return a generator that yields streamed responses
98
- return llm_chain.astream({ "query": query, "context": context, "messages": messages})
99
 
100
  # Invoke the LLM chain and return the result
101
- return llm_chain.invoke({ "query": query, "context": context, "messages": messages})
102
 
103
 
104
 
 
82
  print(e)
83
  return False
84
 
85
+ def generate_response_via_langchain(query: str, stream: bool = False, model: str = "gpt-4o-mini",context:str="",messages = [],style:str="formal",tonality:str="neutral"):
86
  # Define the prompt template
87
+ template = "En tant qu'IA experte en marketing, réponds avec un style {style} et une tonalité {tonality} dans ta communcation, sachant le context suivant: {context}, et l'historique de la conversation: {messages}, {query}"
88
  prompt = PromptTemplate.from_template(template)
89
 
90
  # Initialize the OpenAI LLM with the specified model
 
95
 
96
  if stream:
97
  # Return a generator that yields streamed responses
98
+ return llm_chain.astream({ "query": query, "context": context, "messages": messages, "style": style, "tonality": tonality })
99
 
100
  # Invoke the LLM chain and return the result
101
+ return llm_chain.invoke({"query": query, "context": context, "messages": messages, "style": style, "tonality": tonality})
102
 
103
 
104