DrishtiSharma committed on
Commit
c33ab38
·
verified ·
1 Parent(s): 96ee009

Create interim_v2.py

Browse files
Files changed (1) hide show
  1. interim_v2.py +208 -0
interim_v2.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import chromadb
3
+ import streamlit as st
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain.agents import AgentExecutor, create_openai_tools_agent
6
+ from langchain_core.messages import BaseMessage, HumanMessage
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_experimental.tools import PythonREPLTool
9
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import Chroma
12
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
13
+ from langchain_core.output_parsers import StrOutputParser
14
+ from langchain_core.runnables import RunnablePassthrough
15
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
16
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
+ from langgraph.graph import StateGraph, END
18
+ from langchain_core.documents import Document
19
+ from typing import Annotated, Sequence, TypedDict
20
+ import functools
21
+ import operator
22
+ from langchain_core.tools import tool
23
+ from glob import glob
24
+
25
+
26
# ChromaDB caches a system client per tenant; clearing it prevents the
# "tenant already exists" error on Streamlit reruns.
chromadb.api.client.SharedSystemClient.clear_system_cache()

# Pull API credentials from the environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Abort the app early with a visible message when either key is missing.
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# Single shared chat model instance used by every agent below.
llm = ChatOpenAI(model="gpt-4-1106-preview", openai_api_key=OPENAI_API_KEY)
40
+
41
# Utility Functions
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an AgentExecutor wiring *llm* and *tools* behind *system_prompt*."""
    message_layout = [
        ("system", system_prompt),
        # Conversation history, then the agent's tool-call scratchpad.
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    agent_prompt = ChatPromptTemplate.from_messages(message_layout)
    openai_agent = create_openai_tools_agent(llm, tools, agent_prompt)
    return AgentExecutor(agent=openai_agent, tools=tools)
50
+
51
def agent_node(state, agent, name):
    """Invoke *agent* on the graph state and wrap its output as a message.

    If the agent's answer looks like matplotlib plotting code, execute it and
    stash the rendered figure bytes in ``st.session_state.graph_image`` so the
    UI layer can display them.  Any execution error is appended to the text
    output instead of crashing the workflow.
    """
    result = agent.invoke(state)
    output_content = result["output"]

    # Heuristic: the Coder agent returns raw Python when asked for a chart.
    if "matplotlib" in output_content or "plt." in output_content:
        # FIX: the original referenced `io` and `plt` without importing them
        # anywhere in the file, raising NameError whenever this branch ran.
        # Imported lazily so the rest of the app works without matplotlib.
        import io
        import matplotlib.pyplot as plt

        exec_locals = {}
        try:
            # SECURITY NOTE: exec() runs model-generated code with full
            # interpreter privileges — acceptable only in a trusted demo.
            exec(output_content, {}, exec_locals)
            fig = plt.gcf()  # grab whatever figure the snippet drew

            # Serialize the figure to an in-memory PNG buffer.
            buf = io.BytesIO()
            fig.savefig(buf, format="png")
            buf.seek(0)

            # Hand the image to the Streamlit layer for rendering.
            st.session_state.graph_image = buf
        except Exception as e:
            # Surface the failure in the text output rather than aborting.
            output_content += f"\nError: {str(e)}"

    return {"messages": [HumanMessage(content=output_content, name=name)]}
74
+
75
@tool
def RAG(state):
    """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
    # Record the routing decision for the Streamlit output pane.
    st.session_state.outputs.append('-> Calling RAG ->')
    query = state
    rag_prompt = ChatPromptTemplate.from_template(
        """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
    )
    # Retrieve context, fill the prompt, call the LLM, return plain text.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain.invoke(query)
90
+
91
# ---- Tool setup ----
tavily_tool = TavilySearchResults(max_results=5, tavily_api_key=TAVILY_API_KEY)
python_repl_tool = PythonREPLTool()

# ---- Streamlit page ----
st.title("Multi-Agent w Supervisor")

# Canned prompts so the demo can be exercised immediately.
example_questions = [
    # "Code hello world and print it",
    "What is James McIlroy aiming for in sports?",
    "Fetch India's GDP over the past 5 years and draw a line graph.",
    "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph.",
]
105
+
106
# ---- Knowledge-base file selection ----
source_files = glob("sources/*.txt")
selected_files = st.multiselect(
    "Select files from the source directory:",
    source_files,
    default=source_files[:2],
)

uploaded_files = st.file_uploader(
    "Or upload your TXT files:", accept_multiple_files=True, type=['txt']
)

# Merge directory picks and uploads into one document list.
all_docs = []
for path in selected_files or []:
    all_docs.extend(TextLoader(path).load())

for upload in uploaded_files or []:
    text = upload.read().decode("utf-8")
    all_docs.append(Document(page_content=text, metadata={"name": upload.name}))

# Nothing to index → stop and prompt the user.
if not all_docs:
    st.warning("Please select files from the source directory or upload TXT files.")
    st.stop()
127
+
128
# ---- Document indexing ----
# Small 100-char chunks keep retrieval granular for the demo corpus.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
split_docs = text_splitter.split_documents(all_docs)

# CPU-only BGE embeddings, normalized for cosine-style similarity.
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en-v1.5",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},
)
db = Chroma.from_documents(split_docs, embeddings)
retriever = db.as_retriever(search_kwargs={"k": 4})
135
+
136
# ---- Worker agents ----
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(llm, [python_repl_tool], "You may generate safe python code to analyze data and generate charts using matplotlib.")
RAG_agent = create_agent(llm, [RAG], "Use this tool when questions are related to Japan or Sports category.")

# Bind each executor into a graph-node callable.
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")

# ---- Supervisor ----
members = ["RAG", "Researcher", "Coder"]
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
    "Use RAG tool for Japan or Sports questions."
)
options = ["FINISH"] + members

# OpenAI function schema forcing the model to emit exactly one route choice.
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}},
        "required": ["next"],
    },
}

prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="messages"),
    ("system", "Given the conversation above, who should act next? Select one of: {options}"),
]).partial(options=str(options), members=", ".join(members))

# Prompt → forced function call → parsed {"next": ...} dict.
supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)
162
+
163
# Workflow
class AgentState(TypedDict):
    """Shared LangGraph state: the running message list and the next route."""
    # operator.add tells LangGraph to append each node's messages to the list
    # instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Worker name chosen by the supervisor ("FINISH" ends the run).
    next: str
167
+
168
# ---- Graph wiring ----
workflow = StateGraph(AgentState)

# Register one node per worker plus the supervisor router.
for node_name, node_fn in (
    ("Researcher", research_node),
    ("Coder", code_node),
    ("RAG", rag_node),
    ("supervisor", supervisor_chain),
):
    workflow.add_node(node_name, node_fn)

# Every worker reports back to the supervisor after acting.
for member in members:
    workflow.add_edge(member, "supervisor")

# The supervisor's "next" field routes to a worker, or FINISH ends the run.
conditional_map = {name: name for name in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda state: state["next"], conditional_map)

workflow.set_entry_point("supervisor")
graph = workflow.compile()
181
+
182
# ---- UI: task input ----
# Keep a persistent transcript across Streamlit reruns.
if 'outputs' not in st.session_state:
    st.session_state.outputs = []

user_input = st.text_area(
    "Enter your task or question:", placeholder=example_questions[0]
)
187
+
188
def run_workflow(task):
    """Stream the supervisor graph for *task*, logging each intermediate state."""
    log = st.session_state.outputs
    log.clear()
    log.append(f"User Input: {task}")
    for step in graph.stream({"messages": [HumanMessage(content=task)]}):
        # Skip the terminal sentinel state; record everything else.
        if "__end__" in step:
            continue
        log.append(str(step))
        log.append("----")
195
+
196
# ---- UI: actions and results ----
if st.button("Run Workflow"):
    if not user_input:
        st.warning("Please enter a task or question.")
    else:
        run_workflow(user_input)

st.subheader("Example Questions:")
for question in example_questions:
    st.text(f"- {question}")

st.subheader("Workflow Output:")
for line in st.session_state.outputs:
    st.text(line)