import json
from pprint import pprint
from typing import Annotated, List, TypedDict
import chainlit as cl
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.messages import AIMessage, AIMessageChunk, FunctionMessage
from langchain_core.utils.function_calling import convert_to_openai_function
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from utils.tools import send_text_tool
from .db import get_recipes, shortlisted_recipes_to_string
from .graph_chains import (
get_grader_chain,
get_question_type_chain,
get_recipe_url_extractor_chain,
get_selected_recipe,
)
from .retrievers import get_self_retriever


class AgentState(TypedDict):
    question: str  # last value wins; operator.setitem is not a valid two-argument reducer
question_type: str
generation: str
documents: List[str]
shortlisted_recipes: List[dict]
selected_recipe: dict
messages: Annotated[list, add_messages]
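

# Note on state semantics: `messages` uses the `add_messages` reducer, so values
# returned by nodes are appended to the running history; every other field is
# simply overwritten by the latest node that returns it. Recipe dicts stored in
# `shortlisted_recipes` / `selected_recipe` are expected to carry at least
# "url", "thumbnail", and "text" keys, per the accesses below.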
def generate_workflow(base_llm, power_llm):
    """Build and compile the LangGraph workflow for the recipe assistant.

    `base_llm` drives the cheap classification/extraction chains, while
    `power_llm` handles user-facing generation and tool calling.
    """

    def _node_question_triage(state: AgentState):
print("---TRIAGE---")
question = state["question"]
messages = state["messages"]
        last_message = messages[-1].content if messages else ""
shortlisted_recipes = state.get("shortlisted_recipes")
question_type_chain = get_question_type_chain(base_llm)
question_type_response = question_type_chain.invoke(
{
"question": question,
"context": shortlisted_recipes_to_string(shortlisted_recipes),
"last_message": last_message,
}
)
question_type_response_data = sorted(
[
(question_type_response.send_text, "send_sms"),
(question_type_response.asking_for_recipe_suggestions, "asking_for_recipe_suggestions"),
(question_type_response.referring_to_shortlisted_recipes, "referring_to_shortlisted_recipes"),
(question_type_response.show_specific_recipe, "show_specific_recipe"),
(question_type_response.referring_to_specific_recipe, "referring_to_specific_recipe"),
],
key=lambda x: x[0],
reverse=True,
)
pprint(question_type_response_data)
        # Pick the highest-scoring question type.
        question_type = question_type_response_data[0][1]
        selected_recipe = None
        if shortlisted_recipes and question_type_response.specific_recipe_url:
            # Default to None so an unmatched URL doesn't raise StopIteration.
            selected_recipe = next(
                (r for r in shortlisted_recipes if r["url"] == question_type_response.specific_recipe_url),
                None,
            )
            print("set selected recipe", question_type_response.specific_recipe_url)
        return {"question_type": question_type, "selected_recipe": selected_recipe}
async def _node_call_retriever(state: AgentState, config):
print("---RETRIEVE---")
cl_msg = config["configurable"]["cl_msg"]
await cl_msg.stream_token("Searching for recipes matching your criteria ... \n\n")
question = state["question"]
vector_db_chain = get_self_retriever(base_llm)
# Retrieval
        documents = vector_db_chain.invoke(question, return_only_outputs=False)
        print("Retriever search kwargs:", vector_db_chain.search_kwargs)
        return {"documents": documents, "question": question}

async def _node_grade_recipes(state: AgentState, config):
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
cl_msg = config["configurable"]["cl_msg"]
question = state["question"]
documents = state["documents"]
await cl_msg.stream_token(
f"Evaluating the relevance of {len(documents)} retrieved recipes based on your criteria ... \n\n"
)
retrieval_grader = get_grader_chain(base_llm)
# Score each doc
filtered_docs = []
for d in documents:
grader_output = retrieval_grader.invoke({"question": question, "document": d.page_content})
binary_score = grader_output.binary_score
score = grader_output.integer_score
if binary_score == "yes":
print("---GRADE: DOCUMENT RELEVANT---: ", score, d.metadata["url"])
d.metadata["score"] = score
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---", score, d.metadata["url"])
continue
num_eliminated_docs = len(documents) - len(filtered_docs)
if num_eliminated_docs > 0:
await cl_msg.stream_token(
f"Eliminated {num_eliminated_docs} recipes that were not relevant based on your criteria ... \n\n"
)
return {"documents": filtered_docs, "question": question}

    async def _node_generate_response(state: AgentState, config):
        """
        Generates a shortlist of recipes from the graded documents, streams the
        answer to the user, and looks up full records for any recipe URLs
        mentioned in the response.

        Args:
            state (AgentState): The current graph state
            config: Runnable config carrying the Chainlit message to stream to

        Returns:
            dict: State updates, including the shortlisted recipes
        """
print("--- GENERATING SHORTLIST ---")
question = state["question"]
documents = state["documents"]
        # Build the shortlist prompt for the power LLM
base_rag_prompt_template = """\
You are a friendly AI assistant. Using the provided context,
please answer the user's question in a friendly, conversational tone.
        Based on the context provided, please select the top 3 recipes that best fit the criteria
        outlined in the question. They don't need to be perfect matches; just pick the most suitable options.
        For each option, provide the following information:
        1. A brief description of the recipe
        2. The URL of the recipe
        3. The rating and number of ratings
        Only if the question includes a criterion about a specific occasion, also provide the occasion(s) of the recipe.
        Only if the question includes a criterion about a type of cuisine, also provide the cuisine(s) associated with the recipe.
        Only if the question includes a criterion about a type of diet, also provide the diet(s) associated with the recipe.
        If the context is empty, tell the user that there are no recipes matching those specific requirements and do NOT provide any other recipes as suggestions.
Context:
{context}
Question:
{question}
"""
base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
chain = base_rag_prompt | power_llm
full_response = ""
cl_msg = config["configurable"]["cl_msg"]
async for chunk in chain.astream(
{"question": question, "context": documents},
config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
):
if isinstance(chunk, AIMessageChunk):
await cl_msg.stream_token(chunk.content)
full_response += chunk.content
url_extractor = get_recipe_url_extractor_chain(base_llm)
url_extractor_results = url_extractor.invoke({"context": full_response})
shortlisted_recipes = None
if isinstance(url_extractor_results.urls, list) and len(url_extractor_results.urls):
shortlisted_recipes = get_recipes(url_extractor_results.urls)
        return {
            "documents": documents,
            "question": question,
            "shortlisted_recipes": shortlisted_recipes,
            "messages": [AIMessage(content=full_response)],
        }

    async def _node_shortlist_qa(state: AgentState, config):
print("--- Q&A with SHORTLISTED RECIPES ---")
question = state["question"]
shortlisted_recipes = state["shortlisted_recipes"]
messages = state["messages"]
        last_message = messages[-1].content if messages else ""
        question_type = state["question_type"]
        # Build the Q&A prompt over the shortlisted recipes
base_rag_prompt_template = """\
You are a friendly AI assistant. Using only the provided context,
please answer the user's question in a friendly, conversational tone.
If you don't know the answer based on the context, say you don't know.
Context:
{context}
Last message provided to the user:
{last_message}
Question:
{question}
"""
base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
chain = base_rag_prompt | power_llm
full_response = ""
thumbnail_url = ""
cl_msg = config["configurable"]["cl_msg"]
if state["question_type"] == "show_specific_recipe":
selected_recipe = state.get("selected_recipe")
if selected_recipe and selected_recipe.get("thumbnail"):
thumbnail_url = selected_recipe["thumbnail"]
image = cl.Image(url=thumbnail_url, name="thumbnail", display="inline", size="large")
# Attach the image to the message
await cl.Message(
content="",
elements=[image],
).send()
async for chunk in chain.astream(
{
"question": question,
"context": shortlisted_recipes_to_string(shortlisted_recipes),
"last_message": last_message,
},
config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
):
if isinstance(chunk, AIMessageChunk):
await cl_msg.stream_token(chunk.content)
full_response += chunk.content
        return {
            "messages": [AIMessage(content=full_response)],
        }

    async def _node_single_recipe_qa(state: AgentState, config):
print("--- Q&A with SINGLE RECIPE ---")
question = state["question"]
selected_recipe = state.get("selected_recipe")
messages = state["messages"]
        last_message = messages[-1].content if messages else ""
# LLM with tool and validation
base_rag_prompt_template = """\
You are a friendly AI assistant. Using only the provided context,
please answer the user's question in a friendly, conversational tone.
If you don't know the answer based on the context, say you don't know.
Context:
{context}
Last message provided to the user:
{last_message}
Question:
{question}
"""
base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
power_llm_with_tool = power_llm.bind_functions([convert_to_openai_function(send_text_tool)])
chain = base_rag_prompt | power_llm_with_tool
full_response = ""
cl_msg = config["configurable"]["cl_msg"]
async for chunk in chain.astream(
{"question": question, "context": selected_recipe["text"], "last_message": last_message},
config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
):
if isinstance(chunk, AIMessageChunk):
await cl_msg.stream_token(chunk.content)
full_response += chunk.content
return {"messages": [full_response]}
async def _node_send_sms(state: AgentState, config):
print("--- SEND SMS ---")
question = state["question"]
selected_recipe = state.get("selected_recipe")
messages = state["messages"]
        last_message = messages[-1].content if messages else ""
cl_msg = config["configurable"]["cl_msg"]
# LLM with tool and validation
base_rag_prompt_template = """\
You are a friendly AI assistant.
Using only the provided context and the tool,
        please fulfill the user's request to send an SMS text.
Context:
{context}
Last message provided to the user:
{last_message}
Question:
{question}
"""
base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
power_llm_with_tool = power_llm.bind_functions([convert_to_openai_function(send_text_tool)])
chain = base_rag_prompt | power_llm_with_tool
tool_executor = ToolExecutor([send_text_tool])
message = chain.invoke(
{
"question": question,
"context": selected_recipe.get("text") if selected_recipe else "",
"last_message": last_message,
},
)
print("message", message)
tool_arguments = json.loads(message.additional_kwargs["function_call"]["arguments"])
action = ToolInvocation(
tool=message.additional_kwargs["function_call"]["name"],
tool_input=tool_arguments,
)
response = tool_executor.invoke(action)
function_message = FunctionMessage(content=str(response), name=action.tool)
await cl_msg.stream_token(
f"""Sure! I've sent a text to {tool_arguments['number']} with the following: \n\n{tool_arguments['text']}"""
)
return {"messages": [function_message]}

    workflow = StateGraph(AgentState)

    # Define the nodes
    workflow.add_node("triage", _node_question_triage)  # classify the question type
    workflow.add_node("retrieve", _node_call_retriever)  # retrieve candidate recipes
    workflow.add_node("grade_recipes", _node_grade_recipes)  # grade document relevance
    workflow.add_node("generate", _node_generate_response)  # generate the shortlist
    workflow.add_node("shortlist_qa", _node_shortlist_qa)  # answer questions about shortlisted recipes
    workflow.add_node("single_qa", _node_single_recipe_qa)  # answer questions about a single recipe
    workflow.add_node("send_sms", _node_send_sms)  # send an SMS via the bound tool

    # Define the edges
    def _edge_route_question(state: AgentState):
print("=======EDGE: START =====")
        question_type = state["question_type"]
        selected_recipe = state.get("selected_recipe")
if question_type == "asking_for_recipe_suggestions":
return "retrieve"
if question_type in ["referring_to_shortlisted_recipes", "show_specific_recipe"]:
return "shortlist_qa"
if question_type == "referring_to_specific_recipe" and selected_recipe:
return "single_qa"
if question_type == "send_sms":
return "send_sms"
print("defaulting to shortlist_qa")
return "shortlist_qa"

    workflow.add_edge(START, "triage")
workflow.add_conditional_edges(
"triage",
_edge_route_question,
{
"shortlist_qa": "shortlist_qa",
"single_qa": "single_qa",
"retrieve": "retrieve",
"send_sms": "send_sms",
},
)
workflow.add_edge("retrieve", "grade_recipes")
workflow.add_edge("grade_recipes", "generate")
workflow.add_edge("generate", END)
workflow.add_edge("shortlist_qa", END)
workflow.add_edge("single_qa", END)
workflow.add_edge("send_sms", END)
memory = AsyncSqliteSaver.from_conn_string(":memory:")
app = workflow.compile(checkpointer=memory)
return app
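

# --- Usage sketch (illustrative; not part of this module) ---
# A minimal sketch of how a Chainlit handler might drive the compiled graph.
# The `ChatOpenAI` model names and the "demo" thread_id are assumptions; the
# `cl_msg` entry must be the Chainlit message the nodes stream tokens into via
# `config["configurable"]["cl_msg"]`.
#
#   from langchain_openai import ChatOpenAI
#
#   app = generate_workflow(
#       base_llm=ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
#       power_llm=ChatOpenAI(model="gpt-4", temperature=0),
#   )
#
#   @cl.on_message
#   async def on_message(message: cl.Message):
#       cl_msg = cl.Message(content="")
#       config = {"configurable": {"thread_id": "demo", "cl_msg": cl_msg}}
#       await app.ainvoke({"question": message.content, "messages": []}, config)
#       await cl_msg.send()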