Spaces:
Paused
Paused
import json | |
import operator | |
from pprint import pprint | |
from typing import Annotated, List, TypedDict | |
import chainlit as cl | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema.runnable.config import RunnableConfig | |
from langchain_core.messages import AIMessageChunk, FunctionMessage | |
from langchain_core.utils.function_calling import convert_to_openai_function | |
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver | |
from langgraph.graph import END, START, StateGraph | |
from langgraph.graph.message import add_messages | |
from langgraph.prebuilt import ToolExecutor, ToolInvocation | |
from utils.tools import send_text_tool | |
from .db import get_recipes, shortlisted_recipes_to_string | |
from .graph_chains import ( | |
get_grader_chain, | |
get_question_type_chain, | |
get_recipe_url_extractor_chain, | |
get_selected_recipe, | |
) | |
from .retrievers import get_self_retriever | |
class AgentState(TypedDict):
    """State shared by every node of the recipe-assistant graph.

    Nodes return partial dicts keyed by these field names. Fields without a
    reducer are overwritten by each write; ``messages`` is merged through
    ``add_messages``.
    """

    # The user's current question. Deliberately has no reducer: the previous
    # ``operator.setitem`` annotation is a three-argument, in-place
    # ``a[b] = c`` and cannot act as a two-argument ``(current, update)``
    # reducer, so a plain ``str`` (each write replaces the value) is used.
    question: str
    # Label chosen by the triage node, e.g. "asking_for_recipe_suggestions".
    question_type: str
    # Not written by any node visible in this file.
    generation: str
    # Retrieved items. NOTE(review): annotated as List[str], but the grading
    # node reads ``.page_content`` / ``.metadata`` from each item, so these
    # look like document objects rather than strings - confirm and retype.
    documents: List[str]
    # Recipes looked up from the URLs in the last generated shortlist; the
    # generate node writes None when no URLs were extracted.
    shortlisted_recipes: List[dict]
    # Shortlisted recipe the question refers to; the triage node writes None
    # when no recipe could be resolved.
    selected_recipe: dict
    # Conversation history, merged with langgraph's ``add_messages`` reducer.
    messages: Annotated[list, add_messages]
def generate_workflow(base_llm, power_llm):
    """Build and compile the LangGraph workflow for the recipe assistant.

    The graph starts at a triage node that classifies the question and then
    routes to exactly one branch before ending:

    * ``retrieve`` -> ``grade_recipes`` -> ``generate`` (new suggestions)
    * ``shortlist_qa`` (questions about the current shortlist)
    * ``single_qa`` (questions about one selected recipe)
    * ``send_sms`` (send a text through ``send_text_tool``)

    The async nodes read the Chainlit message to stream into from
    ``config["configurable"]["cl_msg"]``.

    Args:
        base_llm: Model passed to the helper chains (triage, retriever,
            grader, URL extractor).
        power_llm: Model that produces the user-facing answers and the SMS
            function call.

    Returns:
        The compiled graph, using an ``AsyncSqliteSaver`` checkpointer.
    """
    # Prompt shared by the two Q&A nodes (shortlist and single recipe).
    qa_prompt_template = """\
You are a friendly AI assistant. Using only the provided context,
please answer the user's question in a friendly, conversational tone.
If you don't know the answer based on the context, say you don't know.
Context:
{context}
Last message provided to the user:
{last_message}
Question:
{question}
"""

    def _last_message(state: AgentState):
        """Return the newest entry of the message history, or "" when it is empty."""
        history = state["messages"]
        return history[-1] if history else ""

    async def _stream_answer(chain, inputs, cl_msg):
        """Stream ``chain`` output into ``cl_msg`` and return the concatenated text."""
        parts = []
        async for chunk in chain.astream(
            inputs,
            config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
        ):
            # Only AI message chunks carry text for the user.
            if isinstance(chunk, AIMessageChunk):
                await cl_msg.stream_token(chunk.content)
                parts.append(chunk.content)
        return "".join(parts)

    def _node_question_triage(state: AgentState):
        """Classify the question and resolve the shortlisted recipe it targets, if any."""
        print("---TRIAGE---")
        shortlisted_recipes = state.get("shortlisted_recipes")
        response = get_question_type_chain(base_llm).invoke(
            {
                "question": state["question"],
                "context": shortlisted_recipes_to_string(shortlisted_recipes),
                "last_message": _last_message(state),
            }
        )
        # Rank the candidate labels by the value the classifier gave each one;
        # the highest-ranked label wins (ties keep the order listed here).
        ranked_types = sorted(
            [
                (response.send_text, "send_sms"),
                (response.asking_for_recipe_suggestions, "asking_for_recipe_suggestions"),
                (response.referring_to_shortlisted_recipes, "referring_to_shortlisted_recipes"),
                (response.show_specific_recipe, "show_specific_recipe"),
                (response.referring_to_specific_recipe, "referring_to_specific_recipe"),
            ],
            key=operator.itemgetter(0),
            reverse=True,
        )
        pprint(ranked_types)
        question_type = ranked_types[0][1]

        selected_recipe = None
        recipe_url = response.specific_recipe_url
        if shortlisted_recipes and recipe_url:
            # The classifier may name a URL that is not in the shortlist;
            # without a default, next() would raise StopIteration here.
            selected_recipe = next(
                (r for r in shortlisted_recipes if r["url"] == recipe_url),
                None,
            )
            if selected_recipe is not None:
                print("set selected recipe", recipe_url)
        return {"question_type": question_type, "selected_recipe": selected_recipe}

    async def _node_call_retriever(state: AgentState, config):
        """Fetch candidate recipe documents for the question."""
        print("---RETRIEVE---")
        cl_msg = config["configurable"]["cl_msg"]
        await cl_msg.stream_token("Searching for recipes matching your criteria ... \n\n")
        question = state["question"]
        vector_db_chain = get_self_retriever(base_llm)
        # Retrieval
        documents = vector_db_chain.invoke(question, return_only_outputs=False)
        print("WOW: ", vector_db_chain.search_kwargs)
        return {"documents": documents, "question": question}

    async def _node_grade_recipes(state: AgentState, config):
        """Keep only the retrieved documents the grader marks as relevant."""
        print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        cl_msg = config["configurable"]["cl_msg"]
        question = state["question"]
        documents = state["documents"]
        await cl_msg.stream_token(
            f"Evaluating the relevance of {len(documents)} retrieved recipes based on your criteria ... \n\n"
        )
        retrieval_grader = get_grader_chain(base_llm)
        # Score each doc; relevant ones get their score recorded in metadata.
        filtered_docs = []
        for d in documents:
            grader_output = retrieval_grader.invoke({"question": question, "document": d.page_content})
            score = grader_output.integer_score
            if grader_output.binary_score == "yes":
                print("---GRADE: DOCUMENT RELEVANT---: ", score, d.metadata["url"])
                d.metadata["score"] = score
                filtered_docs.append(d)
            else:
                print("---GRADE: DOCUMENT NOT RELEVANT---", score, d.metadata["url"])
        num_eliminated_docs = len(documents) - len(filtered_docs)
        if num_eliminated_docs > 0:
            await cl_msg.stream_token(
                f"Eliminated {num_eliminated_docs} recipes that were not relevant based on your criteria ... \n\n"
            )
        return {"documents": filtered_docs, "question": question}

    async def _node_generate_response(state: AgentState, config):
        """Stream a shortlist of recipes and record which recipes it mentions.

        The answer is generated from the graded documents, streamed to the
        user, and then passed to the URL extractor so the matching recipe
        records can be stored as ``shortlisted_recipes``.

        Args:
            state: The current graph state.
            config: Run config; ``config["configurable"]["cl_msg"]`` is the
                Chainlit message to stream into.

        Returns:
            dict: State update with the documents, question, shortlisted
            recipes (None when no URLs were extracted) and the answer text.
        """
        print("--- GENERATING SHORTLIST ---")
        question = state["question"]
        documents = state["documents"]
        base_rag_prompt_template = """\
You are a friendly AI assistant. Using the provided context,
please answer the user's question in a friendly, conversational tone.
Based on the context provided, please select the top 3 receipes that best fits criteria
outlined in the question. It doesn't need to be a perfect match but just get the most suitable.
For each option, provide the following information:
1. A brief description of the recipe
2. The URL of the recipe
3. The ratings and number of ratings
Only if question includes a criteria for recipes that are good for a specific occassion, please also provide the occassion(s) of the recipe,
Only if question includes a criteria a type of cuisine, please also provide the cuisines associated with the recipe.
Only if question includes a criteria a type of diet, please also provide the diet(s) associated with the recipe.
If the context is empty, please be careful to note to the user that there are no recipes matching those specific requirements and do NOT provide any other recipes as suggestions.
Context:
{context}
Question:
{question}
"""
        chain = ChatPromptTemplate.from_template(base_rag_prompt_template) | power_llm
        cl_msg = config["configurable"]["cl_msg"]
        full_response = await _stream_answer(
            chain,
            {"question": question, "context": documents},
            cl_msg,
        )
        # Recover the recipe URLs from the generated text to look up the recipes.
        url_extractor = get_recipe_url_extractor_chain(base_llm)
        url_extractor_results = url_extractor.invoke({"context": full_response})
        shortlisted_recipes = None
        if isinstance(url_extractor_results.urls, list) and url_extractor_results.urls:
            shortlisted_recipes = get_recipes(url_extractor_results.urls)
        return {
            "documents": documents,
            "question": question,
            "shortlisted_recipes": shortlisted_recipes,
            "messages": [full_response],
        }

    async def _node_shortlist_qa(state: AgentState, config):
        """Answer a question using the shortlisted recipes as context."""
        print("--- Q&A with SHORTLISTED RECIPES ---")
        question = state["question"]
        shortlisted_recipes = state["shortlisted_recipes"]
        last_message = _last_message(state)
        question_type = state["question_type"]
        chain = ChatPromptTemplate.from_template(qa_prompt_template) | power_llm
        cl_msg = config["configurable"]["cl_msg"]
        if question_type == "show_specific_recipe":
            selected_recipe = state.get("selected_recipe")
            if selected_recipe and selected_recipe.get("thumbnail"):
                image = cl.Image(
                    url=selected_recipe["thumbnail"],
                    name="thumbnail",
                    display="inline",
                    size="large",
                )
                # Send the thumbnail as its own message before the answer streams.
                await cl.Message(
                    content="",
                    elements=[image],
                ).send()
        full_response = await _stream_answer(
            chain,
            {
                "question": question,
                "context": shortlisted_recipes_to_string(shortlisted_recipes),
                "last_message": last_message,
            },
            cl_msg,
        )
        return {
            "messages": [full_response],
        }

    async def _node_single_recipe_qa(state: AgentState, config):
        """Answer a question using only the selected recipe's text as context.

        The router only sends a run here when ``selected_recipe`` is set.
        """
        print("--- Q&A with SINGLE RECIPE ---")
        question = state["question"]
        selected_recipe = state.get("selected_recipe")
        last_message = _last_message(state)
        power_llm_with_tool = power_llm.bind_functions([convert_to_openai_function(send_text_tool)])
        chain = ChatPromptTemplate.from_template(qa_prompt_template) | power_llm_with_tool
        cl_msg = config["configurable"]["cl_msg"]
        full_response = await _stream_answer(
            chain,
            {"question": question, "context": selected_recipe["text"], "last_message": last_message},
            cl_msg,
        )
        return {"messages": [full_response]}

    async def _node_send_sms(state: AgentState, config):
        """Ask the model for a ``send_text_tool`` call, run it, and confirm to the user."""
        print("--- SEND SMS ---")
        question = state["question"]
        selected_recipe = state.get("selected_recipe")
        last_message = _last_message(state)
        cl_msg = config["configurable"]["cl_msg"]
        base_rag_prompt_template = """\
You are a friendly AI assistant.
Using only the provided context and the tool,
please fullfill the user's request to send an SMS text
Context:
{context}
Last message provided to the user:
{last_message}
Question:
{question}
"""
        power_llm_with_tool = power_llm.bind_functions([convert_to_openai_function(send_text_tool)])
        chain = ChatPromptTemplate.from_template(base_rag_prompt_template) | power_llm_with_tool
        message = chain.invoke(
            {
                "question": question,
                "context": selected_recipe.get("text") if selected_recipe else "",
                "last_message": last_message,
            },
        )
        print("message", message)
        function_call = message.additional_kwargs.get("function_call")
        if not function_call:
            # The model replied in plain text instead of calling the tool
            # (e.g. to ask for a phone number); relay that rather than
            # failing with a KeyError.
            fallback = message.content or "Sorry, I wasn't able to send that text message."
            await cl_msg.stream_token(fallback)
            return {"messages": [fallback]}
        tool_arguments = json.loads(function_call["arguments"])
        action = ToolInvocation(
            tool=function_call["name"],
            tool_input=tool_arguments,
        )
        response = ToolExecutor([send_text_tool]).invoke(action)
        function_message = FunctionMessage(content=str(response), name=action.tool)
        await cl_msg.stream_token(
            f"""Sure! I've sent a text to {tool_arguments['number']} with the following: \n\n{tool_arguments['text']}"""
        )
        return {"messages": [function_message]}

    workflow = StateGraph(AgentState)
    # Define the nodes
    workflow.add_node("triage", _node_question_triage)  # classify the question
    workflow.add_node("retrieve", _node_call_retriever)  # retrieve candidate recipes
    workflow.add_node("grade_recipes", _node_grade_recipes)  # grade documents
    workflow.add_node("generate", _node_generate_response)  # generate the shortlist
    workflow.add_node("shortlist_qa", _node_shortlist_qa)  # answer questions about shortlisted recipes
    workflow.add_node("single_qa", _node_single_recipe_qa)  # answer questions about one selected recipe
    workflow.add_node("send_sms", _node_send_sms)  # send an SMS through the tool

    # Define the edges
    def _edge_route_question(state: AgentState):
        """Map the triaged question type to the name of the next node."""
        print("=======EDGE: START =====")
        question_type = state["question_type"]
        if question_type == "asking_for_recipe_suggestions":
            return "retrieve"
        if question_type in ("referring_to_shortlisted_recipes", "show_specific_recipe"):
            return "shortlist_qa"
        # Single-recipe Q&A needs a resolved recipe; otherwise fall through.
        if question_type == "referring_to_specific_recipe" and state.get("selected_recipe"):
            return "single_qa"
        if question_type == "send_sms":
            return "send_sms"
        print("defaulting to shortlist_qa")
        return "shortlist_qa"

    workflow.add_edge(START, "triage")
    workflow.add_conditional_edges(
        "triage",
        _edge_route_question,
        {
            "shortlist_qa": "shortlist_qa",
            "single_qa": "single_qa",
            "retrieve": "retrieve",
            "send_sms": "send_sms",
        },
    )
    workflow.add_edge("retrieve", "grade_recipes")
    workflow.add_edge("grade_recipes", "generate")
    workflow.add_edge("generate", END)
    workflow.add_edge("shortlist_qa", END)
    workflow.add_edge("single_qa", END)
    workflow.add_edge("send_sms", END)
    # NOTE(review): this relies on from_conn_string returning a saver object
    # directly; confirm against the pinned langgraph version.
    memory = AsyncSqliteSaver.from_conn_string(":memory:")
    app = workflow.compile(checkpointer=memory)
    return app