Maximofn committed on
Commit
664d175
·
1 Parent(s): a73e772

feat(ENDPOINT): :pushpin: Add new summarize endpoint

Browse files
Files changed (2) hide show
  1. README.md +4 -0
  2. app.py +67 -6
README.md CHANGED
@@ -49,6 +49,10 @@ Welcome endpoint that returns a greeting message.
49
 
50
  Endpoint to generate text using the language model.
51
 
 
 
 
 
52
  **Request parameters:**
53
  ```json
54
  {
 
49
 
50
  Endpoint to generate text using the language model.
51
 
52
+ ### POST `/summarize`
53
+
54
+ Endpoint to summarize text using the language model.
55
+
56
  **Request parameters:**
57
  ```json
58
  {
app.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
 
5
 
6
  from langchain_core.messages import HumanMessage, AIMessage
7
  from langgraph.checkpoint.memory import MemorySaver
@@ -42,7 +43,7 @@ except Exception as e:
42
  raise
43
 
44
  # Define the function that calls the model
45
- def call_model(state: MessagesState):
46
  """
47
  Call the model with the given messages
48
 
@@ -54,7 +55,7 @@ def call_model(state: MessagesState):
54
  """
55
  # Convert LangChain messages to chat format
56
  messages = [
57
- {"role": "system", "content": "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."}
58
  ]
59
 
60
  for msg in state["messages"]:
@@ -95,12 +96,26 @@ workflow.add_node("model", call_model)
95
 
96
  # Add memory
97
  memory = MemorySaver()
 
 
 
 
 
 
 
98
  graph_app = workflow.compile(checkpointer=memory)
99
 
100
  # Define the data model for the request
101
  class QueryRequest(BaseModel):
102
  query: str
103
  thread_id: str = "default"
 
 
 
 
 
 
 
104
 
105
  # Create the FastAPI application
106
  app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph - Máximo Fernández Núñez IriusRisk test challenge")
@@ -119,8 +134,9 @@ async def generate(request: QueryRequest):
119
 
120
  Args:
121
  request: QueryRequest
122
- query: str
123
- thread_id: str = "default"
 
124
 
125
  Returns:
126
  dict: A dictionary containing the generated text and the thread ID
@@ -132,8 +148,12 @@ async def generate(request: QueryRequest):
132
  # Create the input message
133
  input_messages = [HumanMessage(content=request.query)]
134
 
135
- # Invoke the graph
136
- output = graph_app.invoke({"messages": input_messages}, config)
 
 
 
 
137
 
138
  # Get the model response
139
  response = output["messages"][-1].content
@@ -145,6 +165,47 @@ async def generate(request: QueryRequest):
145
  except Exception as e:
146
  raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  if __name__ == "__main__":
149
  import uvicorn
150
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
+ from functools import partial
6
 
7
  from langchain_core.messages import HumanMessage, AIMessage
8
  from langgraph.checkpoint.memory import MemorySaver
 
43
  raise
44
 
45
  # Define the function that calls the model
46
+ def call_model(state: MessagesState, system_prompt: str):
47
  """
48
  Call the model with the given messages
49
 
 
55
  """
56
  # Convert LangChain messages to chat format
57
  messages = [
58
+ {"role": "system", "content": system_prompt}
59
  ]
60
 
61
  for msg in state["messages"]:
 
96
 
97
  # Add memory
98
  memory = MemorySaver()
99
+
100
+ # Define the default system prompt
101
+ DEFAULT_SYSTEM_PROMPT = "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."
102
+
103
+ # Use partial to create a version of the function with the default system prompt
104
+ workflow.add_node("model", partial(call_model, system_prompt=DEFAULT_SYSTEM_PROMPT))
105
+
106
  graph_app = workflow.compile(checkpointer=memory)
107
 
108
  # Define the data model for the request
109
  class QueryRequest(BaseModel):
110
  query: str
111
  thread_id: str = "default"
112
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT
113
+
114
+ # Define the model for summary requests
115
+ class SummaryRequest(BaseModel):
116
+ text: str
117
+ thread_id: str = "default"
118
+ max_length: int = 200
119
 
120
  # Create the FastAPI application
121
  app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph - Máximo Fernández Núñez IriusRisk test challenge")
 
134
 
135
  Args:
136
  request: QueryRequest
137
+ query: str
138
+ thread_id: str = "default"
139
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT
140
 
141
  Returns:
142
  dict: A dictionary containing the generated text and the thread ID
 
148
  # Create the input message
149
  input_messages = [HumanMessage(content=request.query)]
150
 
151
+ # Invoke the graph with custom system prompt
152
+ output = graph_app.invoke(
153
+ {"messages": input_messages},
154
+ config,
155
+ {"model": {"system_prompt": request.system_prompt}}
156
+ )
157
 
158
  # Get the model response
159
  response = output["messages"][-1].content
 
165
  except Exception as e:
166
  raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
167
 
168
+ @app.post("/summarize")
169
+ async def summarize(request: SummaryRequest):
170
+ """
171
+ Endpoint to generate a summary using the language model
172
+
173
+ Args:
174
+ request: SummaryRequest
175
+ text: str - The text to summarize
176
+ thread_id: str = "default"
177
+ max_length: int = 200 - Maximum summary length
178
+
179
+ Returns:
180
+ dict: A dictionary containing the summary and the thread ID
181
+ """
182
+ try:
183
+ # Configure the thread ID
184
+ config = {"configurable": {"thread_id": request.thread_id}}
185
+
186
+ # Create a specific system prompt for summarization
187
+ summary_system_prompt = f"Make a summary of the following text in no more than {request.max_length} words. Keep the most important information and eliminate unnecessary details."
188
+
189
+ # Create the input message
190
+ input_messages = [HumanMessage(content=request.text)]
191
+
192
+ # Invoke the graph with summarization system prompt
193
+ output = graph_app.invoke(
194
+ {"messages": input_messages},
195
+ config,
196
+ {"model": {"system_prompt": summary_system_prompt}}
197
+ )
198
+
199
+ # Get the model response
200
+ response = output["messages"][-1].content
201
+
202
+ return {
203
+ "summary": response,
204
+ "thread_id": request.thread_id
205
+ }
206
+ except Exception as e:
207
+ raise HTTPException(status_code=500, detail=f"Error generating summary: {str(e)}")
208
+
209
  if __name__ == "__main__":
210
  import uvicorn
211
  uvicorn.run(app, host="0.0.0.0", port=7860)