Joshua Sundance Bailey commited on
Commit
827bf89
1 Parent(s): b0a3c25
langchain-streamlit-demo/app.py CHANGED
@@ -9,7 +9,7 @@ from langchain.schema.runnable import RunnableConfig
9
  from langsmith.client import Client
10
 
11
  from llm_stuff import (
12
- _MEMORY,
13
  _MODEL_DICT,
14
  _SUPPORTED_MODELS,
15
  _DEFAULT_MODEL,
@@ -37,6 +37,14 @@ if "trace_link" not in st.session_state:
37
  st.session_state.trace_link = None
38
  if "run_id" not in st.session_state:
39
  st.session_state.run_id = None
 
 
 
 
 
 
 
 
40
 
41
  model = st.sidebar.selectbox(
42
  label="Chat Model",
@@ -76,6 +84,12 @@ system_prompt = (
76
  .replace("}", "}}")
77
  )
78
 
 
 
 
 
 
 
79
  temperature = st.sidebar.slider(
80
  "Temperature",
81
  min_value=_MIN_TEMPERATURE,
@@ -103,19 +117,6 @@ if provider_api_key:
103
 
104
  run_collector = RunCollectorCallbackHandler()
105
 
106
- if st.sidebar.button("Clear message history"):
107
- print("Clearing message history")
108
- st.session_state["langchain_messages"].memory.clear()
109
- st.session_state.trace_link = None
110
- st.session_state.run_id = None
111
-
112
- for msg in _MEMORY.messages:
113
- with st.chat_message(
114
- msg.type,
115
- avatar="🦜" if msg.type in ("ai", "assistant") else None,
116
- ):
117
- st.markdown(msg.content)
118
-
119
 
120
  def _reset_feedback():
121
  st.session_state.feedback_update = None
 
9
  from langsmith.client import Client
10
 
11
  from llm_stuff import (
12
+ _STMEMORY,
13
  _MODEL_DICT,
14
  _SUPPORTED_MODELS,
15
  _DEFAULT_MODEL,
 
37
  st.session_state.trace_link = None
38
  if "run_id" not in st.session_state:
39
  st.session_state.run_id = None
40
+ if len(_STMEMORY.messages) == 0:
41
+ _STMEMORY.add_ai_message("Hello! I'm a helpful AI chatbot. Ask me a question!")
42
+
43
+ for msg in _STMEMORY.messages:
44
+ st.chat_message(
45
+ msg.type,
46
+ avatar="🦜" if msg.type in ("ai", "assistant") else None,
47
+ ).write(msg.content)
48
 
49
  model = st.sidebar.selectbox(
50
  label="Chat Model",
 
84
  .replace("}", "}}")
85
  )
86
 
87
+ if st.sidebar.button("Clear message history"):
88
+ print("Clearing message history")
89
+ _STMEMORY.clear()
90
+ st.session_state.trace_link = None
91
+ st.session_state.run_id = None
92
+
93
  temperature = st.sidebar.slider(
94
  "Temperature",
95
  min_value=_MIN_TEMPERATURE,
 
117
 
118
  run_collector = RunCollectorCallbackHandler()
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def _reset_feedback():
122
  st.session_state.feedback_update = None
langchain-streamlit-demo/llm_stuff.py CHANGED
@@ -6,11 +6,15 @@ from langchain.callbacks.base import BaseCallbackHandler
6
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
7
  from langchain.chat_models.base import BaseChatModel
8
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
9
- from langchain.memory.chat_memory import BaseChatMemory
10
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
11
  from streamlit_feedback import streamlit_feedback
12
 
13
- _MEMORY = StreamlitChatMessageHistory(key="langchain_messages")
 
 
 
 
 
14
 
15
  _DEFAULT_SYSTEM_PROMPT = "You are a helpful chatbot."
16
 
@@ -35,14 +39,6 @@ _MIN_TOKENS = 1
35
  _MAX_TOKENS = 100000
36
 
37
 
38
- def get_memory() -> BaseChatMemory:
39
- return ConversationBufferMemory(
40
- chat_memory=_MEMORY,
41
- return_messages=True,
42
- memory_key="chat_history",
43
- )
44
-
45
-
46
  def get_llm(
47
  model: str,
48
  provider_api_key: str,
@@ -95,9 +91,8 @@ def get_llm_chain(
95
  ("human", "{input}"),
96
  ],
97
  ).partial(time=lambda: str(datetime.now()))
98
- memory = get_memory()
99
  llm = get_llm(model, provider_api_key, temperature, max_tokens)
100
- return LLMChain(prompt=prompt, llm=llm, memory=memory or get_memory())
101
 
102
 
103
  class StreamHandler(BaseCallbackHandler):
 
6
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
7
  from langchain.chat_models.base import BaseChatModel
8
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 
9
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
10
  from streamlit_feedback import streamlit_feedback
11
 
12
+ _STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
13
+ _MEMORY = ConversationBufferMemory(
14
+ chat_memory=_STMEMORY,
15
+ return_messages=True,
16
+ memory_key="chat_history",
17
+ )
18
 
19
  _DEFAULT_SYSTEM_PROMPT = "You are a helpful chatbot."
20
 
 
39
  _MAX_TOKENS = 100000
40
 
41
 
 
 
 
 
 
 
 
 
42
  def get_llm(
43
  model: str,
44
  provider_api_key: str,
 
91
  ("human", "{input}"),
92
  ],
93
  ).partial(time=lambda: str(datetime.now()))
 
94
  llm = get_llm(model, provider_api_key, temperature, max_tokens)
95
+ return LLMChain(prompt=prompt, llm=llm, memory=_MEMORY)
96
 
97
 
98
  class StreamHandler(BaseCallbackHandler):