import sys
import os
from contextlib import contextmanager
from typing import List, Dict

from langchain.schema import Document
from langgraph.graph import END, StateGraph
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
from typing_extensions import TypedDict
from IPython.display import display, HTML, Image

from .chains.answer_chitchat import make_chitchat_node
from .chains.answer_ai_impact import make_ai_impact_node
from .chains.query_transformation import make_query_transform_node
from .chains.translation import make_translation_node
from .chains.intent_categorization import make_intent_categorization_node
from .chains.retrieve_documents import make_retriever_node
from .chains.answer_rag import make_rag_node
from .chains.graph_retriever import make_graph_retriever_node
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
# from .chains.set_defaults import set_defaults


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    NOTE: a TypedDict does not apply class-level default values to the
    dictionaries that flow through the graph; the assignments below only
    record the intended defaults. Callers must supply these keys themselves
    (the disabled ``set_defaults`` node looks intended for that -- confirm).
    """

    user_input: str
    language: str  # compared case-insensitively with "english" in route_translation
    intent: str  # read by route_intent
    search_graphs_chitchat: bool  # read by chitchat_route_intent
    query: str
    remaining_questions: List[dict]  # non-empty => retrieve_documents loops again
    n_questions: int
    answer: str
    audience: str = "experts"
    sources_input: List[str] = ["IPCC", "IPBES"]
    relevant_content_sources: List[str] = ["IPCC figures"]  # checked for "OurWorldInData"
    sources_auto: bool = True
    min_year: int = 1960
    max_year: int = None
    documents: List[Document]  # each document's metadata must hold "reranking_score"
    related_contents: Dict[str, Document]
    recommended_content: List[Document]
    search_only: bool = False  # True => stop right after document retrieval


def search(state: GraphState):
    """Pass-through node marking the start of the search branch."""
    # TODO
    return state


def answer_search(state: GraphState):
    """Pass-through node preceding the choice between the two RAG answers."""
    # TODO
    return state


def route_intent(state: GraphState) -> str:
    """Pick the branch that follows intent categorization.

    Returns:
        "answer_chitchat" for the "chitchat" and "esg" intents, "search" for
        every other intent.
    """
    if state["intent"] in ("chitchat", "esg"):
        return "answer_chitchat"
    # NOTE: a dedicated "ai_impact" -> "answer_ai_impact" route is currently
    # disabled, so that intent falls through to the search route as well.
    return "search"


def chitchat_route_intent(state: GraphState):
    """Decide whether a chitchat answer is followed by a graph retrieval.

    Returns:
        "retrieve_graphs_chitchat" when ``search_graphs_chitchat`` is exactly
        ``True``; ``END`` otherwise.
    """
    if state["search_graphs_chitchat"] is True:
        return "retrieve_graphs_chitchat"
    # Anything other than True (False, None, ...) ends the run. Previously a
    # non-boolean value made this router return None, which is not a key of
    # the path map registered for this conditional edge.
    return END


def route_translation(state: GraphState) -> str:
    """Skip the translation node when the detected language is English."""
    is_english = state["language"].lower() == "english"
    return "transform_query" if is_english else "translate_query"


def route_based_on_relevant_docs(state: GraphState, threshold_docs=0.2) -> str:
    """Choose the RAG answer node from the documents' reranking scores.

    Args:
        state: Graph state; ``state["documents"]`` items expose
            ``metadata["reranking_score"]``.
        threshold_docs: A document counts as relevant when its score is
            strictly greater than this value.

    Returns:
        "answer_rag" if at least one document is relevant, otherwise
        "answer_rag_no_docs".
    """
    has_relevant_docs = any(
        doc.metadata["reranking_score"] > threshold_docs
        for doc in state["documents"]
    )
    return "answer_rag" if has_relevant_docs else "answer_rag_no_docs"


def route_retrieve_documents(state: GraphState):
    """Decide what follows a document-retrieval step.

    Returns:
        ``END`` in search-only mode, "retrieve_documents" while questions
        remain to be processed, and "answer_search" otherwise.
    """
    if state["search_only"]:
        return END
    if state["remaining_questions"]:
        return "retrieve_documents"
    return "answer_search"


def make_id_dict(values):
    """Return a dict mapping every item of ``values`` to itself.

    Used to build the path maps of the conditional edges, where each router
    return value is also the name of the target node.
    """
    return {k: k for k in values}


def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
    """Build and compile the agent's state graph.

    Args:
        llm: Language model handed to the node factories.
        vectorstore_ipcc: Vector store handed to the document retriever node.
        vectorstore_graphs: Vector store handed to the graph retriever node.
        reranker: Reranker handed to both retriever nodes.
        threshold_docs: Reranking-score threshold forwarded to
            ``route_based_on_relevant_docs``.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    workflow = StateGraph(GraphState)

    # Define the node functions
    categorize_intent = make_intent_categorization_node(llm)
    transform_query = make_query_transform_node(llm)
    translate_query = make_translation_node(llm)
    answer_chitchat = make_chitchat_node(llm)
    # NOTE: built but not registered below; the "ai_impact" route is disabled.
    answer_ai_impact = make_ai_impact_node(llm)
    retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
    answer_rag = make_rag_node(llm, with_docs=True)
    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)

    # Define the nodes
    # workflow.add_node("set_defaults", set_defaults)
    workflow.add_node("categorize_intent", categorize_intent)
    workflow.add_node("search", search)
    workflow.add_node("answer_search", answer_search)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("translate_query", translate_query)
    workflow.add_node("answer_chitchat", answer_chitchat)
    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
    workflow.add_node("retrieve_graphs", retrieve_graphs)
    # Same retriever under a second name, reached from the chitchat branch.
    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
    workflow.add_node("retrieve_documents", retrieve_documents)
    workflow.add_node("answer_rag", answer_rag)
    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)

    # Entry point
    workflow.set_entry_point("categorize_intent")

    # CONDITIONAL EDGES
    workflow.add_conditional_edges(
        "categorize_intent",
        route_intent,
        make_id_dict(["answer_chitchat", "search"]),
    )

    workflow.add_conditional_edges(
        "chitchat_categorize_intent",
        chitchat_route_intent,
        make_id_dict(["retrieve_graphs_chitchat", END]),
    )

    workflow.add_conditional_edges(
        "search",
        route_translation,
        make_id_dict(["translate_query", "transform_query"]),
    )

    workflow.add_conditional_edges(
        "retrieve_documents",
        route_retrieve_documents,
        make_id_dict([END, "retrieve_documents", "answer_search"]),
    )

    workflow.add_conditional_edges(
        "answer_search",
        lambda x: route_based_on_relevant_docs(x, threshold_docs=threshold_docs),
        make_id_dict(["answer_rag", "answer_rag_no_docs"]),
    )

    # In addition to the unconditional edge to "retrieve_documents" declared
    # below, "transform_query" also leads to the graph retriever when
    # OurWorldInData content was requested.
    workflow.add_conditional_edges(
        "transform_query",
        lambda state: "retrieve_graphs" if "OurWorldInData" in state["relevant_content_sources"] else END,
        make_id_dict(["retrieve_graphs", END]),
    )

    # Define the edges
    workflow.add_edge("translate_query", "transform_query")
    workflow.add_edge("transform_query", "retrieve_documents")

    workflow.add_edge("retrieve_graphs", END)
    # Terminate the chitchat graph branch explicitly, mirroring
    # "retrieve_graphs"; the node previously had no outgoing edge at all.
    workflow.add_edge("retrieve_graphs_chitchat", END)
    workflow.add_edge("answer_rag", END)
    workflow.add_edge("answer_rag_no_docs", END)
    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")

    # Compile
    app = workflow.compile()
    return app


def display_graph(app):
    """Display the compiled graph as a Mermaid-rendered PNG.

    The image is drawn with ``MermaidDrawMethod.API`` (presumably a remote
    rendering service -- confirm network access is acceptable) and shown via
    ``IPython.display``.
    """
    display(
        Image(
            app.get_graph(xray=True).draw_mermaid_png(
                draw_method=MermaidDrawMethod.API,
            )
        )
    )