Joshua Sundance Bailey commited on
Commit
b0a3c25
β€’
2 Parent(s): db5e26f 60ba98a

Merge pull request #11 from joshuasundance-swca/dev

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -3,13 +3,13 @@ import os
3
  import anthropic
4
  import openai
5
  import streamlit as st
6
- from langchain.callbacks.manager import tracing_v2_enabled
7
  from langchain.callbacks.tracers.langchain import wait_for_all_tracers
8
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
9
  from langchain.schema.runnable import RunnableConfig
10
  from langsmith.client import Client
11
 
12
  from llm_stuff import (
 
13
  _MODEL_DICT,
14
  _SUPPORTED_MODELS,
15
  _DEFAULT_MODEL,
@@ -45,52 +45,54 @@ model = st.sidebar.selectbox(
45
  )
46
  provider = _MODEL_DICT[model]
47
 
48
- if provider_api_key := st.sidebar.text_input(f"{provider} API key", type="password"):
49
- if langsmith_api_key := st.sidebar.text_input(
50
- "LangSmith API Key (optional)",
51
- type="password",
52
- ):
53
- langsmith_project = st.sidebar.text_input(
54
- "LangSmith Project Name",
55
- value="langchain-streamlit-demo",
56
- )
57
- os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
58
- os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
59
- os.environ["LANGCHAIN_TRACING_V2"] = "true"
60
- os.environ["LANGCHAIN_PROJECT"] = langsmith_project
61
-
62
- client = Client(api_key=langsmith_api_key)
63
- else:
64
- langsmith_project = None
65
- client = None
66
-
67
- system_prompt = (
68
- st.sidebar.text_area(
69
- "Custom Instructions",
70
- _DEFAULT_SYSTEM_PROMPT,
71
- help="Custom instructions to provide the language model to determine style, personality, etc.",
72
- )
73
- .strip()
74
- .replace("{", "{{")
75
- .replace("}", "}}")
76
  )
 
 
 
 
77
 
78
- temperature = st.sidebar.slider(
79
- "Temperature",
80
- min_value=_MIN_TEMPERATURE,
81
- max_value=_MAX_TEMPERATURE,
82
- value=_DEFAULT_TEMPERATURE,
83
- help="Higher values give more random results.",
 
 
 
 
84
  )
 
 
 
 
85
 
86
- max_tokens = st.sidebar.slider(
87
- "Max Tokens",
88
- min_value=_MIN_TOKENS,
89
- max_value=_MAX_TOKENS,
90
- value=_DEFAULT_MAX_TOKENS,
91
- help="Higher values give longer results.",
92
- )
93
 
 
 
 
 
 
 
 
 
 
94
  chain = get_llm_chain(
95
  model,
96
  provider_api_key,
@@ -99,29 +101,30 @@ if provider_api_key := st.sidebar.text_input(f"{provider} API key", type="passwo
99
  max_tokens,
100
  )
101
 
102
- run_collector = RunCollectorCallbackHandler()
103
 
104
- if st.sidebar.button("Clear message history"):
105
- print("Clearing message history")
106
- chain.memory.clear()
107
- st.session_state.trace_link = None
108
- st.session_state.run_id = None
 
 
 
 
 
 
 
109
 
110
- for msg in st.session_state.langchain_messages:
111
- with st.chat_message(msg.type, avatar="🦜" if msg.type == "assistant" else None):
112
- st.markdown(msg.content)
113
 
114
- if client and st.session_state.trace_link:
115
- st.sidebar.markdown(
116
- f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
117
- unsafe_allow_html=True,
118
- )
119
 
120
- def _reset_feedback():
121
- st.session_state.feedback_update = None
122
- st.session_state.feedback = None
123
 
124
- if prompt := st.chat_input(placeholder="Ask me a question!"):
 
 
125
  st.chat_message("user").write(prompt)
126
  _reset_feedback()
127
 
@@ -133,17 +136,10 @@ if provider_api_key := st.sidebar.text_input(f"{provider} API key", type="passwo
133
  tags=["Streamlit Chat"],
134
  )
135
  try:
136
- if client and langsmith_project:
137
- with tracing_v2_enabled(project_name=langsmith_project):
138
- full_response = chain.invoke(
139
- {"input": prompt},
140
- config=runnable_config,
141
- )["text"]
142
- else:
143
- full_response = chain.invoke(
144
- {"input": prompt},
145
- config=runnable_config,
146
- )["text"]
147
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
148
  st.error(f"Please enter a valid {provider} API key.", icon="❌")
149
  st.stop()
@@ -156,10 +152,14 @@ if provider_api_key := st.sidebar.text_input(f"{provider} API key", type="passwo
156
  wait_for_all_tracers()
157
  url = client.read_run(run.id).url
158
  st.session_state.trace_link = url
159
-
160
  if client and st.session_state.get("run_id"):
161
  feedback_component(client)
162
 
163
  else:
164
  st.error(f"Please enter a valid {provider} API key.", icon="❌")
165
- st.stop()
 
 
 
 
 
 
3
  import anthropic
4
  import openai
5
  import streamlit as st
 
6
  from langchain.callbacks.tracers.langchain import wait_for_all_tracers
7
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
8
  from langchain.schema.runnable import RunnableConfig
9
  from langsmith.client import Client
10
 
11
  from llm_stuff import (
12
+ _MEMORY,
13
  _MODEL_DICT,
14
  _SUPPORTED_MODELS,
15
  _DEFAULT_MODEL,
 
45
  )
46
  provider = _MODEL_DICT[model]
47
 
48
+ provider_api_key = st.sidebar.text_input(f"{provider} API key", type="password")
49
+ langsmith_api_key = st.sidebar.text_input(
50
+ "LangSmith API Key (optional)",
51
+ type="password",
52
+ )
53
+ if langsmith_api_key:
54
+ langsmith_project = st.sidebar.text_input(
55
+ "LangSmith Project Name",
56
+ value="langchain-streamlit-demo",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  )
58
+ os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
59
+ os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
60
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
61
+ os.environ["LANGCHAIN_PROJECT"] = langsmith_project
62
 
63
+ client = Client(api_key=langsmith_api_key)
64
+ else:
65
+ langsmith_project = None
66
+ client = None
67
+
68
+ system_prompt = (
69
+ st.sidebar.text_area(
70
+ "Custom Instructions",
71
+ _DEFAULT_SYSTEM_PROMPT,
72
+ help="Custom instructions to provide the language model to determine style, personality, etc.",
73
  )
74
+ .strip()
75
+ .replace("{", "{{")
76
+ .replace("}", "}}")
77
+ )
78
 
79
+ temperature = st.sidebar.slider(
80
+ "Temperature",
81
+ min_value=_MIN_TEMPERATURE,
82
+ max_value=_MAX_TEMPERATURE,
83
+ value=_DEFAULT_TEMPERATURE,
84
+ help="Higher values give more random results.",
85
+ )
86
 
87
+ max_tokens = st.sidebar.slider(
88
+ "Max Tokens",
89
+ min_value=_MIN_TOKENS,
90
+ max_value=_MAX_TOKENS,
91
+ value=_DEFAULT_MAX_TOKENS,
92
+ help="Higher values give longer results.",
93
+ )
94
+ chain = None
95
+ if provider_api_key:
96
  chain = get_llm_chain(
97
  model,
98
  provider_api_key,
 
101
  max_tokens,
102
  )
103
 
104
+ run_collector = RunCollectorCallbackHandler()
105
 
106
+ if st.sidebar.button("Clear message history"):
107
+ print("Clearing message history")
108
+ st.session_state["langchain_messages"].memory.clear()
109
+ st.session_state.trace_link = None
110
+ st.session_state.run_id = None
111
+
112
+ for msg in _MEMORY.messages:
113
+ with st.chat_message(
114
+ msg.type,
115
+ avatar="🦜" if msg.type in ("ai", "assistant") else None,
116
+ ):
117
+ st.markdown(msg.content)
118
 
 
 
 
119
 
120
+ def _reset_feedback():
121
+ st.session_state.feedback_update = None
122
+ st.session_state.feedback = None
 
 
123
 
 
 
 
124
 
125
+ if chain:
126
+ prompt = st.chat_input(placeholder="Ask me a question!")
127
+ if prompt:
128
  st.chat_message("user").write(prompt)
129
  _reset_feedback()
130
 
 
136
  tags=["Streamlit Chat"],
137
  )
138
  try:
139
+ full_response = chain.invoke(
140
+ {"input": prompt},
141
+ config=runnable_config,
142
+ )["text"]
 
 
 
 
 
 
 
143
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
144
  st.error(f"Please enter a valid {provider} API key.", icon="❌")
145
  st.stop()
 
152
  wait_for_all_tracers()
153
  url = client.read_run(run.id).url
154
  st.session_state.trace_link = url
 
155
  if client and st.session_state.get("run_id"):
156
  feedback_component(client)
157
 
158
  else:
159
  st.error(f"Please enter a valid {provider} API key.", icon="❌")
160
+
161
+ if client and st.session_state.get("trace_link"):
162
+ st.sidebar.markdown(
163
+ f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
164
+ unsafe_allow_html=True,
165
+ )
langchain-streamlit-demo/llm_stuff.py CHANGED
@@ -10,6 +10,8 @@ from langchain.memory.chat_memory import BaseChatMemory
10
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
11
  from streamlit_feedback import streamlit_feedback
12
 
 
 
13
  _DEFAULT_SYSTEM_PROMPT = "You are a helpful chatbot."
14
 
15
  _MODEL_DICT = {
@@ -35,7 +37,7 @@ _MAX_TOKENS = 100000
35
 
36
  def get_memory() -> BaseChatMemory:
37
  return ConversationBufferMemory(
38
- chat_memory=StreamlitChatMessageHistory(key="langchain_messages"),
39
  return_messages=True,
40
  memory_key="chat_history",
41
  )
@@ -109,7 +111,7 @@ class StreamHandler(BaseCallbackHandler):
109
 
110
 
111
  def feedback_component(client):
112
- scores = {"πŸ˜€": 1, "πŸ™‚": 0.0, "😐": 0.5, "πŸ™": 0.25, "😞": 0}
113
  if feedback := streamlit_feedback(
114
  feedback_type="faces",
115
  optional_text_label="[Optional] Please provide an explanation",
 
10
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
11
  from streamlit_feedback import streamlit_feedback
12
 
13
+ _MEMORY = StreamlitChatMessageHistory(key="langchain_messages")
14
+
15
  _DEFAULT_SYSTEM_PROMPT = "You are a helpful chatbot."
16
 
17
  _MODEL_DICT = {
 
37
 
38
  def get_memory() -> BaseChatMemory:
39
  return ConversationBufferMemory(
40
+ chat_memory=_MEMORY,
41
  return_messages=True,
42
  memory_key="chat_history",
43
  )
 
111
 
112
 
113
  def feedback_component(client):
114
+ scores = {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0}
115
  if feedback := streamlit_feedback(
116
  feedback_type="faces",
117
  optional_text_label="[Optional] Please provide an explanation",