import os
import chromadb
import streamlit as st
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from langchain_core.documents import Document
from typing import Annotated, Sequence, TypedDict
import functools
import operator
import io
import matplotlib.pyplot as plt
from langchain_core.tools import tool
from glob import glob
# Clear ChromaDB cache to fix tenant issue
chromadb.api.client.SharedSystemClient.clear_system_cache()
# Load environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY as environment variables.")
st.stop()
# Initialize the LLM
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
# Utility Functions
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
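    """Build an OpenAI-tools agent wrapped in an AgentExecutor.
    The prompt carries the running conversation plus the agent_scratchpad
    placeholder that create_openai_tools_agent needs for tool-call bookkeeping."""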
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
MessagesPlaceholder(variable_name="messages"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])
agent = create_openai_tools_agent(llm, tools, prompt)
return AgentExecutor(agent=agent, tools=tools)
def agent_node(state, agent, name):
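    """LangGraph node wrapper: run one agent on the current state and return
    its output as a named HumanMessage appended to the shared message list."""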
# Run the agent and get its output
result = agent.invoke(state)
output_content = result["output"]
# Check if the output contains Python code that generates a graph
if "matplotlib" in output_content or "plt." in output_content:
exec_locals = {}
try:
            exec(output_content, {}, exec_locals)  # Note: exec on raw LLM output is not sandboxed; failures are caught below
fig = plt.gcf() # Get the current matplotlib figure
# Save the figure to a buffer
buf = io.BytesIO()
fig.savefig(buf, format="png")
buf.seek(0)
# Add image to session state for display
st.session_state.graph_image = buf
except Exception as e:
output_content += f"\nError: {str(e)}"
return {"messages": [HumanMessage(content=output_content, name=name)]}
@tool
def RAG(state):
"""Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
st.session_state.outputs.append('-> Calling RAG ->')
    question = state  # the tool receives the question as a plain string
template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
prompt = ChatPromptTemplate.from_template(template)
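    # `retriever` is the module-level retriever built after document processing;
    # Python resolves the name at call time, so the tool works once it exists.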
retrieval_chain = (
{"context": retriever, "question": RunnablePassthrough()} |
prompt |
llm |
StrOutputParser()
)
result = retrieval_chain.invoke(question)
return result
# Load Tools
tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()
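# Note: PythonREPLTool executes arbitrary Python in-process; suitable for demos, not untrusted input.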
# Streamlit UI
st.title("Multi-Agent Workflow with Supervisor")
# Example questions for immediate testing
example_questions = [
#"Code hello world and print it",
"What is James McIlroy aiming for in sports?",
"Fetch India's GDP over the past 5 years and draw a line graph.",
"Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph."
]
# File Selection Section
source_files = glob("sources/*.txt")
selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])
uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])
# Combine Files
all_docs = []
if selected_files:
for file_path in selected_files:
loader = TextLoader(file_path)
all_docs.extend(loader.load())
if uploaded_files:
for uploaded_file in uploaded_files:
content = uploaded_file.read().decode("utf-8")
all_docs.append(Document(page_content=content, metadata={"name": uploaded_file.name}))
if not all_docs:
st.warning("Please select files from the source directory or upload TXT files.")
st.stop()
# Process Documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
split_docs = text_splitter.split_documents(all_docs)
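# BGE embeddings with normalization enabled, so Chroma's similarity search behaves like cosine similarity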
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})
db = Chroma.from_documents(split_docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 4})
# Create Agents
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "You answer questions about Japan or Sports by calling the RAG tool.")
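# Bind the generic agent_node to each agent so LangGraph can call them as graph nodes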
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
members = ["RAG", "Researcher", "Coder"]
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
    "Route questions about Japan or Sports to RAG."
)
options = ["FINISH"] + members
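# OpenAI function-calling schema; binding it with function_call="route" forces the
# supervisor to return exactly one of the routing options.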
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema", "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"],
    },
}
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
MessagesPlaceholder(variable_name="messages"),
("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))
supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
# Workflow
class AgentState(TypedDict):
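    # messages accumulates across nodes via operator.add; next holds the supervisor's routing choice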
messages: Annotated[Sequence[BaseMessage], operator.add]
next: str
workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("RAG", rag_node)
workflow.add_node("supervisor", supervisor_chain)
for member in members:
workflow.add_edge(member, "supervisor")
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
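# Topology: the supervisor routes to one worker per step; every worker reports back
# to the supervisor, which either re-routes or returns FINISH to end the run.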
# Workflow Execution
if 'outputs' not in st.session_state:
st.session_state.outputs = []
user_input = st.text_area("Enter your task or question:", placeholder=example_questions[0])
def run_workflow(task):
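    """Stream the graph for a single task, logging each intermediate state."""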
st.session_state.outputs.clear()
st.session_state.outputs.append(f"User Input: {task}")
for state in graph.stream({"messages": [HumanMessage(content=task)]}):
if "__end__" not in state:
st.session_state.outputs.append(str(state))
st.session_state.outputs.append("----")
if st.button("Run Workflow"):
if user_input:
run_workflow(user_input)
else:
st.warning("Please enter a task or question.")
st.subheader("Example Questions:")
for example in example_questions:
st.text(f"- {example}")
st.subheader("Workflow Output:")
for output in st.session_state.outputs:
st.text(output)
# Display any chart captured by agent_node
if "graph_image" in st.session_state:
    st.image(st.session_state.graph_image)