import os
import functools
import operator
from typing import Annotated, Sequence, TypedDict

import chromadb
import streamlit as st
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from langchain_core.documents import Document
from langchain_core.tools import tool

# Clear ChromaDB's cached system client to fix the tenant issue on Streamlit reruns
chromadb.api.client.SharedSystemClient.clear_system_cache()

# Load environment variables from a .env file, if present
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# Initialize the LLM shared by all agents
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)

# Utility Functions
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_openai_tools_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools)

def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}
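# Each agent node wraps its AgentExecutor output in a named HumanMessage so the
# supervisor can attribute replies in the shared graph state. A minimal sketch
# of the state update one node returns (the content value is illustrative,
# not real output):
#   {"messages": [HumanMessage(content="<agent answer>", name="Researcher")]}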
@tool
def RAG(question: str) -> str:
    """Use this tool to execute RAG. If the question is related to Japan or
    Sports, this tool retrieves the results."""
    st.session_state.outputs.append('-> Calling RAG ->')
    template = """Answer the question based only on the following context:
{context}
Question: {question}"""
    prompt = ChatPromptTemplate.from_template(template)
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return retrieval_chain.invoke(question)

# Load Tools and Retriever
tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()

# File Upload Section
st.title("Multi-Agent Workflow Demonstration")
uploaded_files = st.file_uploader("Upload your source files (TXT)", accept_multiple_files=True, type=['txt'])

if uploaded_files:
    docs = []
    for uploaded_file in uploaded_files:
        content = uploaded_file.read().decode("utf-8")
        docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
    new_docs = text_splitter.split_documents(documents=docs)
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    db = Chroma.from_documents(new_docs, embeddings)
    retriever = db.as_retriever(search_kwargs={"k": 4})
else:
    retriever = None
    st.warning("Please upload at least one text file to proceed.")
    st.stop()

# Create Agents
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to the Japan or Sports categories.")

research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")

members = ["RAG", "Researcher", "Coder"]
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker "
    "to act, or FINISH when the task is done. Route Japan or Sports questions to the RAG worker."
)
options = ["FINISH"] + members
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}},
        "required": ["next"],
    },
}
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
    ("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))

supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)
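# Forcing the "route" function call and parsing it with JsonOutputFunctionsParser
# yields a plain dict the graph can branch on. A sketch of one routing step under
# the definitions above (input text and result value are illustrative):
#   supervisor_chain.invoke({"messages": [HumanMessage(content="Which sports are popular in Japan?")]})
#   # -> {"next": "RAG"}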
# Build Workflow
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str

workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("RAG", rag_node)
workflow.add_node("supervisor", supervisor_chain)

# Every worker reports back to the supervisor after acting
for member in members:
    workflow.add_edge(member, "supervisor")

# The supervisor's "next" field picks the worker to run, or ends the graph
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
workflow.set_entry_point("supervisor")
graph = workflow.compile()

# Streamlit UI
if 'outputs' not in st.session_state:
    st.session_state.outputs = []

user_input = st.text_area("Enter your task or question:")

def run_workflow(task):
    st.session_state.outputs.clear()
    st.session_state.outputs.append(f"User Input: {task}")
    for state in graph.stream({"messages": [HumanMessage(content=task)]}):
        if "__end__" not in state:
            st.session_state.outputs.append(str(state))
            st.session_state.outputs.append("----")

if st.button("Run Workflow"):
    if user_input:
        run_workflow(user_input)
    else:
        st.warning("Please enter a task or question.")

st.subheader("Workflow Output:")
for output in st.session_state.outputs:
    st.text(output)
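# To launch the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py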