import functools
import io
import operator
import os
from glob import glob
from typing import Annotated, Sequence, TypedDict

import chromadb
import matplotlib.pyplot as plt
import streamlit as st
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
from langchain_experimental.tools import PythonREPLTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

# Clear the ChromaDB client cache so Streamlit reruns start with a fresh store
chromadb.api.client.SharedSystemClient.clear_system_cache()

# Load environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# Initialize LLM
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)

# Utility functions
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an OpenAI-tools agent wrapped in an AgentExecutor."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_openai_tools_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools)

def agent_node(state, agent, name):
    """Run an agent on the current state and wrap its output as a message."""
    result = agent.invoke(state)
    output_content = result["output"]

    # If the output looks like matplotlib code, execute it and capture the
    # current figure as a PNG so the UI can display it
    if "matplotlib" in output_content or "plt." in output_content:
        exec_locals = {}
        try:
            exec(output_content, {}, exec_locals)
            fig = plt.gcf()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
            buf.seek(0)
            st.session_state.graph_image = buf
        except Exception as e:
            output_content += f"\nError: {str(e)}"

    return {"messages": [HumanMessage(content=output_content, name=name)]}

@tool
def RAG(question: str) -> str:
    """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
    st.session_state.outputs.append("-> Calling RAG ->")
    template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return retrieval_chain.invoke(question)

# Tools setup
tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()
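# A minimal smoke test for one worker in isolation (a sketch, not part of the
# app flow; the query string is hypothetical). The executor expects a
# "messages" list because create_agent's prompt reads a MessagesPlaceholder
# named "messages":
#
#   demo_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
#   demo_agent.invoke({"messages": [HumanMessage(content="Who is James McIlroy?")]})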
# Streamlit UI
st.sidebar.title("References")
st.sidebar.markdown("1. [Multi-Agent with Supervisor](https://github.com/aritrasen87/LLM_RAG_Model_Deployment/blob/main/LangGraph_03_MultiAgent_With_Supervisor.ipynb)")

st.title("Multi-Agent with Supervisor")

example_questions = [
    "What is James McIlroy aiming for in sports?",
    "Fetch India's GDP over the past 5 years and draw a line graph.",
    "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph.",
]

source_files = glob("sources/*.txt")
selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])
uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=["txt"])

# Document handling
all_docs = []
if selected_files:
    for file_path in selected_files:
        loader = TextLoader(file_path)
        all_docs.extend(loader.load())
if uploaded_files:
    for uploaded_file in uploaded_files:
        content = uploaded_file.read().decode("utf-8")
        all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
if not all_docs:
    st.warning("Please select files or upload TXT files.")
    st.stop()

# Document splitting and embedding
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10)
split_docs = text_splitter.split_documents(all_docs)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
db = Chroma.from_documents(split_docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 4})

# Worker agents
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")

research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")

# Supervisor: routes each turn to the next worker, or FINISH to stop
members = ["RAG", "Researcher", "Coder"]
system_prompt = "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH."
options = ["FINISH"] + members
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}},
        "required": ["next"],
    },
}
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
    ("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))
supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()  # yields e.g. {"next": "Researcher"}
)

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str

workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("RAG", rag_node)
workflow.add_node("supervisor", supervisor_chain)

# Every worker reports back to the supervisor
for member in members:
    workflow.add_edge(member, "supervisor")

# The supervisor's "next" field routes to a worker; FINISH ends the run
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
workflow.set_entry_point("supervisor")
graph = workflow.compile()

# Workflow execution
if "outputs" not in st.session_state:
    st.session_state.outputs = []

user_input = st.text_area("Enter your task or question:", value=example_questions[0])

def run_workflow(task):
    st.session_state.outputs.clear()
    st.session_state.outputs.append(f"User Input: {task}")
    st.session_state.graph_image = None
    for state in graph.stream({"messages": [HumanMessage(content=task)]}):
        if "__end__" not in state:
            st.session_state.outputs.append(str(state))
            st.session_state.outputs.append("----")

if st.button("Run Workflow"):
    if user_input:
        run_workflow(user_input)
    else:
        st.warning("Please enter a task or question.")

st.subheader("Workflow Output:")
for output in st.session_state.outputs:
    st.text(output)

if "graph_image" in st.session_state and st.session_state.graph_image:
    st.subheader("Generated Graph:")
    st.image(st.session_state.graph_image, caption="Generated Line Graph", use_column_width=True)
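# To run the app locally (assuming this file is saved as app.py and a
# sources/ directory of .txt files sits next to it):
#
#   export OPENAI_API_KEY=...
#   export TAVILY_API_KEY=...
#   streamlit run app.py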