# Ask_TA / app.py
# gokulp06's picture
# Create app.py
# 176ef1a verified
from langchain.agents import ConversationalChatAgent, AgentExecutor
from langchain.memory import ConversationBufferMemory
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
import streamlit as st
# Browser-tab title/icon and the page heading.
# The emoji are written as escape sequences so the file stays pure ASCII: the
# previous literals had been corrupted by a wrong-codec round trip (UTF-8 bytes
# decoded as a single-byte code page), leaving unreadable text in the UI.
st.set_page_config(
    page_title="Ask Your TA",
    # U+1F9D1 U+200D U+1F52C: "scientist" (person + zero-width joiner + microscope).
    page_icon="\U0001F9D1\u200D\U0001F52C",
)
# The corrupted title only contained an orphaned joiner plus the microscope;
# the joiner is invisible on its own, so just the microscope is kept.
st.title("\U0001F52C Ask Your TA")
# Sidebar field for the user's OpenAI API key; type="password" masks the input.
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")

# Message store for the conversation, wrapped in a buffer memory that is later
# handed to the agent executor. output_key names the response field the memory
# should record (the executor is configured to return intermediate steps too).
msgs = StreamlitChatMessageHistory()
memory = ConversationBufferMemory(
    chat_memory=msgs,
    return_messages=True,
    memory_key="chat_history",
    output_key="output",
)

# Create the reset button unconditionally. Evaluating it on the right-hand side
# of `or` (after the empty-history test) short-circuited the call whenever the
# history was empty, so the button was missing from the sidebar on that run.
reset_requested = st.sidebar.button("Reset chat history")

# Start (or restart) the conversation with a greeting and forget any saved
# intermediate agent steps.
if reset_requested or not msgs.messages:
    msgs.clear()
    msgs.add_ai_message("How can I help you?")
    st.session_state.steps = {}
# Translates a stored message's type into the role name given to st.chat_message.
avatars = {"human": "user", "ai": "assistant"}

# Replay the whole conversation. Saved agent steps are looked up in
# st.session_state.steps under the message's position, converted to a string.
for position, message in enumerate(msgs.messages):
    with st.chat_message(avatars[message.type]):
        saved_steps = st.session_state.steps.get(str(position), [])
        for step in saved_steps:
            # step[0] carries the tool name, tool input and log text;
            # step[1] is whatever was recorded alongside it.
            action = step[0]
            # Entries tagged "_Exception" are not shown -- presumably these are
            # placeholders for recovered parsing errors (TODO confirm).
            if action.tool != "_Exception":
                with st.status(f"**{action.tool}**: {action.tool_input}", state="complete"):
                    st.write(action.log)
                    st.write(step[1])
        # The message text itself goes below any replayed steps.
        st.write(message.content)
# Handle a question submitted through the chat box on this run, if any.
prompt = st.chat_input(placeholder="Who won the Women's U.S. Open in 2018?")
if prompt:
    # Echo the question right away, before any model work starts.
    st.chat_message("user").write(prompt)

    # No key means no model call: tell the user and halt this script run.
    if not openai_api_key:
        st.info("Please add your OpenAI API key to continue.")
        st.stop()

    # Streaming chat model plus a single web-search tool named "Search".
    chat_model = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        openai_api_key=openai_api_key,
        streaming=True,
    )
    search_tools = [DuckDuckGoSearchRun(name="Search")]

    # Conversational agent wired to the shared memory. Intermediate steps are
    # returned so they can be saved for replay; parsing errors are handled by
    # the executor instead of being raised.
    agent_executor = AgentExecutor.from_agent_and_tools(
        agent=ConversationalChatAgent.from_llm_and_tools(llm=chat_model, tools=search_tools),
        tools=search_tools,
        memory=memory,
        return_intermediate_steps=True,
        handle_parsing_errors=True,
    )

    with st.chat_message("assistant"):
        # The callback handler renders the agent's progress into a container
        # placed inside the assistant's chat bubble.
        progress_handler = StreamlitCallbackHandler(st.container(), expand_new_thoughts=False)
        run_config = RunnableConfig(callbacks=[progress_handler])
        result = agent_executor.invoke(prompt, run_config)
        st.write(result["output"])

        # Save the steps under the index of the last stored message --
        # presumably the answer just appended via the memory (TODO confirm).
        answer_key = str(len(msgs.messages) - 1)
        st.session_state.steps[answer_key] = result["intermediate_steps"]