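"""Streamlit demo: a LangGraph multi-agent workflow in which a supervisor LLM
routes each user question to one of three workers -- a RAG agent over local
text files, a Tavily web researcher, or a Python-REPL coding agent."""
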
import functools
import io
import operator
import os
from glob import glob
from typing import Annotated, Sequence, TypedDict

import chromadb
import matplotlib.pyplot as plt
import streamlit as st
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_core.tools import tool
from langchain_experimental.tools import PythonREPLTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

# Reset Chroma's shared client cache so Streamlit reruns start from a clean slate.
chromadb.api.client.SharedSystemClient.clear_system_cache()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Fail fast with a visible error if either key is missing.
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# One chat model instance is shared by all worker agents and the supervisor.
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)


def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build a tool-calling agent and wrap it in an AgentExecutor."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_openai_tools_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools)


def agent_node(state, agent, name):
    """Run one worker agent and re-wrap its output as a message for the graph."""
    result = agent.invoke(state)
    output_content = result["output"]

    # Crude heuristic: if the answer looks like matplotlib code, execute it and
    # capture the current figure so the UI can display it later.
    if "matplotlib" in output_content or "plt." in output_content:
        exec_locals = {}
        try:
            # Caution: exec() runs model-generated code with no sandboxing.
            exec(output_content, {}, exec_locals)
            fig = plt.gcf()
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
            buf.seek(0)
            st.session_state.graph_image = buf
            plt.close(fig)
        except Exception as e:
            output_content += f"\nError: {str(e)}"

    return {"messages": [HumanMessage(content=output_content, name=name)]}


@tool
def RAG(state: str) -> str:
    """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
    st.session_state.outputs.append('-> Calling RAG ->')
    question = state
    template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
    prompt = ChatPromptTemplate.from_template(template)
    # `retriever` and `llm` are module-level globals defined before the graph runs.
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    result = retrieval_chain.invoke(question)
    return result


tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()

st.title("Multi-Agent with Supervisor")

example_questions = [
    "What is James McIlroy aiming for in sports?",
    "Fetch India's GDP over the past 5 years and draw a line graph.",
    "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph.",
]

# Build the RAG corpus from local .txt files and/or user uploads.
source_files = glob("sources/*.txt")
selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])

uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])

all_docs = []
if selected_files:
    for file_path in selected_files:
        loader = TextLoader(file_path)
        all_docs.extend(loader.load())

if uploaded_files:
    for uploaded_file in uploaded_files:
        content = uploaded_file.read().decode("utf-8")
        all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))

if not all_docs:
    st.warning("Please select files from the source directory or upload TXT files.")
    st.stop()

# Chunk the documents and index them in an in-memory Chroma store.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
split_docs = text_splitter.split_documents(all_docs)

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
db = Chroma.from_documents(split_docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 4})
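
# Optional sanity check: retriever.invoke("<question>") should return the four
# most similar chunks as Document objects (k=4 above).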

research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "Use this tool when questions relate to the Japan or Sports category.")

# Bind each agent into a graph-node callable with a fixed name.
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")

members = ["RAG", "Researcher", "Coder"]
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
    "Use the RAG tool for Japan or Sports questions."
)
options = ["FINISH"] + members

# OpenAI function schema that forces the supervisor to pick exactly one route.
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}},
        "required": ["next"],
    },
}

prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
    ("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))

supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
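
# The supervisor never answers the user itself: the forced "route" function call,
# parsed by JsonOutputFunctionsParser, always yields {"next": "<worker or FINISH>"}.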


class AgentState(TypedDict):
    # operator.add makes LangGraph append new messages instead of overwriting them.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str

workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("RAG", rag_node)
workflow.add_node("supervisor", supervisor_chain)

# Every worker reports back to the supervisor, which either routes to the next
# worker or returns FINISH (mapped to END).
for member in members:
    workflow.add_edge(member, "supervisor")
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
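
# For reference, the compiled graph is itself runnable outside Streamlit,
# e.g. graph.invoke({"messages": [HumanMessage(content="...")]}).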

# Session state keeps the running log across Streamlit reruns.
if 'outputs' not in st.session_state:
    st.session_state.outputs = []

user_input = st.text_area("Enter your task or question:", placeholder=example_questions[0])


def run_workflow(task):
    """Stream the graph and record each intermediate state for display."""
    st.session_state.outputs.clear()
    st.session_state.outputs.append(f"User Input: {task}")
    for state in graph.stream({"messages": [HumanMessage(content=task)]}):
        if "__end__" not in state:
            st.session_state.outputs.append(str(state))
            st.session_state.outputs.append("----")


if st.button("Run Workflow"):
    if user_input:
        run_workflow(user_input)
    else:
        st.warning("Please enter a task or question.")

st.subheader("Example Questions:")
for example in example_questions:
    st.text(f"- {example}")

st.subheader("Workflow Output:")
for output in st.session_state.outputs:
    st.text(output)
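
# agent_node stores any captured matplotlib figure in st.session_state.graph_image,
# but nothing above renders it. A minimal sketch of displaying the saved PNG buffer
# (st.image accepts a BytesIO object directly):
if "graph_image" in st.session_state:
    st.subheader("Generated Chart:")
    st.image(st.session_state.graph_image)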