DrishtiSharma's picture
Update app.py
d77fbbb verified
raw
history blame
6.66 kB
import os
import chromadb
import streamlit as st
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from langchain_core.documents import Document
from typing import Annotated, Sequence, TypedDict
import functools
import operator
from langchain_core.tools import tool
# Clear ChromaDB's cached system client so repeated Streamlit reruns do not
# fail with the stale-tenant error from a previously created client.
chromadb.api.client.SharedSystemClient.clear_system_cache()

# Load environment variables. Fix: load_dotenv was imported but never called,
# so keys stored in a local .env file were silently ignored and the app would
# stop even when the file was present. Calling it is a no-op when no .env exists.
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# Initialize API keys and LLM — one shared model instance serves every worker
# agent and the supervisor router.
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
# Utility Functions
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an AgentExecutor wrapping an OpenAI-tools agent.

    The chat prompt exposes the running conversation through the
    ``messages`` placeholder and the agent's intermediate tool calls
    through ``agent_scratchpad``.
    """
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    openai_agent = create_openai_tools_agent(llm, tools, chat_prompt)
    return AgentExecutor(agent=openai_agent, tools=tools)
def agent_node(state, agent, name):
    """Run *agent* on the graph state and wrap its output as a named message."""
    output_text = agent.invoke(state)["output"]
    return {"messages": [HumanMessage(content=output_text, name=name)]}
@tool
def RAG(state):
    # NOTE: this docstring doubles as the tool description the routing LLM
    # sees — its wording is behavior, keep it stable.
    """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
    # Surface the tool invocation in the Streamlit workflow log.
    st.session_state.outputs.append('-> Calling RAG ->')
    question = state  # the tool receives the raw question text as its single argument
    # Prompt that restricts the answer to the retrieved context only.
    template = """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
    prompt = ChatPromptTemplate.from_template(template)
    # LCEL pipeline: retrieve docs -> fill prompt -> LLM -> plain string.
    # `retriever` and `llm` are module-level globals; NOTE(review): `retriever`
    # is only bound when files were uploaded — the upload section calls
    # st.stop() otherwise, so this tool should never run with retriever=None.
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()} |
        prompt |
        llm |
        StrOutputParser()
    )
    result = retrieval_chain.invoke(question)
    return result
# Load Tools and Retriever
tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()

# File Upload Section — without at least one uploaded file there is nothing
# to retrieve from, so stop the app early with a warning.
st.title("Multi-Agent Workflow Demonstration")
uploaded_files = st.file_uploader("Upload your source files (TXT)", accept_multiple_files=True, type=['txt'])
if not uploaded_files:
    retriever = None
    st.warning("Please upload at least one text file to proceed.")
    st.stop()

# Wrap each uploaded file in a Document, chunk, embed, and expose a retriever.
docs = [
    Document(page_content=uploaded.read().decode("utf-8"), metadata={"name": uploaded.name})
    for uploaded in uploaded_files
]
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
chunks = splitter.split_documents(documents=docs)
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en-v1.5",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},
)
db = Chroma.from_documents(chunks, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 4})
# Create Agents
# Each worker agent owns exactly one tool; the supervisor routes between them.
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")
# Bind each agent into a LangGraph node callable with signature (state) -> dict.
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")
# Supervisor: an LLM router that, each turn, picks the next worker or FINISH.
members = ["RAG", "Researcher", "Coder"]
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
    "Use RAG tool for Japan or Sports questions."
)
options = ["FINISH"] + members
# OpenAI function-calling schema that forces the model's reply into the shape
# {"next": <one of options>}.
function_def = {
    "name": "route", "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema", "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}}, "required": ["next"]
    }
}
# {members} and {options} are pre-filled via .partial; only {messages} varies per turn.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
    ("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))
# function_call="route" forces the model to call the routing function; the
# parser then returns its JSON arguments (a dict with key "next").
supervisor_chain = (prompt | llm.bind_functions(functions=[function_def], function_call="route") | JsonOutputFunctionsParser())
# Build Workflow
class AgentState(TypedDict):
    # Conversation history; operator.add makes LangGraph append (rather than
    # replace) the messages each node returns.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the next node chosen by the supervisor ("FINISH" ends the run).
    next: str
# Assemble the graph: the supervisor fans out to the workers, and every
# worker hands control straight back to the supervisor.
workflow = StateGraph(AgentState)
worker_nodes = {"Researcher": research_node, "Coder": code_node, "RAG": rag_node}
for node_name, node_fn in worker_nodes.items():
    workflow.add_node(node_name, node_fn)
workflow.add_node("supervisor", supervisor_chain)
for node_name in members:
    workflow.add_edge(node_name, "supervisor")
# Route on the supervisor's "next" field; "FINISH" terminates the graph.
routing = {name: name for name in members}
routing["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda state: state["next"], routing)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
# Streamlit UI
# Persist the workflow log across Streamlit reruns (RAG also appends to it).
if 'outputs' not in st.session_state:
    st.session_state.outputs = []
user_input = st.text_area("Enter your task or question:")
def run_workflow(task):
    """Stream the graph on *task*, recording each intermediate state in the log."""
    log = st.session_state.outputs
    log.clear()
    log.append(f"User Input: {task}")
    for step in graph.stream({"messages": [HumanMessage(content=task)]}):
        if "__end__" in step:
            continue
        log.append(str(step))
        log.append("----")
if st.button("Run Workflow"):
    if not user_input:
        st.warning("Please enter a task or question.")
    else:
        run_workflow(user_input)

# Render the accumulated log, one line per entry.
st.subheader("Workflow Output:")
for entry in st.session_state.outputs:
    st.text(entry)