Spaces:
Running
Running
improve chat experience
Browse files- app.py +1 -1
- callbacks/arxiv_callbacks.py +32 -3
- chat.py +23 -4
- helper.py +66 -8
app.py
CHANGED
@@ -28,7 +28,7 @@ st.markdown(
|
|
28 |
)
|
29 |
st.header("ChatData")
|
30 |
|
31 |
-
if '
|
32 |
st.session_state["sel_map_obj"] = build_all()
|
33 |
st.session_state["tools"] = build_tools()
|
34 |
|
|
|
28 |
)
|
29 |
st.header("ChatData")
|
30 |
|
31 |
+
if 'sel_map_obj' not in st.session_state:
|
32 |
st.session_state["sel_map_obj"] = build_all()
|
33 |
st.session_state["tools"] = build_tools()
|
34 |
|
callbacks/arxiv_callbacks.py
CHANGED
@@ -1,8 +1,11 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
|
|
|
|
3 |
from sql_formatter.core import format_sql
|
4 |
-
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
5 |
from langchain.schema.output import LLMResult
|
|
|
6 |
|
7 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
8 |
def __init__(self) -> None:
|
@@ -91,4 +94,30 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
|
91 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
92 |
self.status_bar = st.empty()
|
93 |
self.prog_value = 0
|
94 |
-
self.prog_interval = 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import json
|
3 |
+
import textwrap
|
4 |
+
from typing import Dict, Any, List
|
5 |
from sql_formatter.core import format_sql
|
6 |
+
from langchain.callbacks.streamlit.streamlit_callback_handler import LLMThought, StreamlitCallbackHandler
|
7 |
from langchain.schema.output import LLMResult
|
8 |
+
from streamlit.delta_generator import DeltaGenerator
|
9 |
|
10 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
11 |
def __init__(self) -> None:
|
|
|
94 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
95 |
self.status_bar = st.empty()
|
96 |
self.prog_value = 0
|
97 |
+
self.prog_interval = 0.1
|
98 |
+
|
99 |
+
|
100 |
+
class LLMThoughtWithKB(LLMThought):
|
101 |
+
def on_tool_end(self, output: str, color: str | None = None, observation_prefix: str | None = None, llm_prefix: str | None = None, **kwargs: Any) -> None:
|
102 |
+
try:
|
103 |
+
self._container.markdown("\n\n".join(["### Retrieved Documents:"] + \
|
104 |
+
[f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
|
105 |
+
for i, r in enumerate(json.loads(output))]))
|
106 |
+
except Exception as e:
|
107 |
+
super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
|
108 |
+
|
109 |
+
|
110 |
+
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
|
111 |
+
|
112 |
+
def on_llm_start(
|
113 |
+
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
114 |
+
) -> None:
|
115 |
+
if self._current_thought is None:
|
116 |
+
self._current_thought = LLMThoughtWithKB(
|
117 |
+
parent_container=self._parent_container,
|
118 |
+
expanded=self._expand_new_thoughts,
|
119 |
+
collapse_on_complete=self._collapse_completed_thoughts,
|
120 |
+
labeler=self._thought_labeler,
|
121 |
+
)
|
122 |
+
|
123 |
+
self._current_thought.on_llm_start(serialized, prompts)
|
chat.py
CHANGED
@@ -5,6 +5,8 @@ import datetime
|
|
5 |
import streamlit as st
|
6 |
from lib.sessions import SessionManager
|
7 |
from langchain.schema import HumanMessage, FunctionMessage
|
|
|
|
|
8 |
|
9 |
from helper import (
|
10 |
build_agents,
|
@@ -25,8 +27,14 @@ TOOL_NAMES = {
|
|
25 |
|
26 |
|
27 |
def on_chat_submit():
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
def clear_history():
|
@@ -136,6 +144,12 @@ def chat_page():
|
|
136 |
with st.sidebar:
|
137 |
with st.expander("Session Management"):
|
138 |
refresh_sessions()
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
st.data_editor(
|
140 |
st.session_state.current_sessions,
|
141 |
num_rows="dynamic",
|
@@ -144,6 +158,8 @@ def chat_page():
|
|
144 |
)
|
145 |
st.button("Submit Change!", on_click=on_session_change_submit)
|
146 |
with st.expander("Session Selection", expanded=True):
|
|
|
|
|
147 |
try:
|
148 |
dfl_indx = [
|
149 |
x["session_id"] for x in st.session_state.current_sessions
|
@@ -152,7 +168,7 @@ def chat_page():
|
|
152 |
print("*** ", str(e))
|
153 |
dfl_indx = 0
|
154 |
st.selectbox(
|
155 |
-
"Choose a session
|
156 |
options=st.session_state.current_sessions,
|
157 |
index=dfl_indx,
|
158 |
key="sel_sess",
|
@@ -161,10 +177,12 @@ def chat_page():
|
|
161 |
)
|
162 |
print(st.session_state.sel_sess)
|
163 |
with st.expander("Tool Settings", expanded=True):
|
|
|
|
|
164 |
st.multiselect(
|
165 |
"Knowledge Base",
|
166 |
st.session_state.tools.keys(),
|
167 |
-
default=["
|
168 |
key="selected_tools",
|
169 |
on_change=refresh_agent,
|
170 |
)
|
@@ -195,4 +213,5 @@ def chat_page():
|
|
195 |
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
196 |
)
|
197 |
st.write(f"{msg.content}")
|
|
|
198 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
|
|
5 |
import streamlit as st
|
6 |
from lib.sessions import SessionManager
|
7 |
from langchain.schema import HumanMessage, FunctionMessage
|
8 |
+
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
9 |
+
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
10 |
|
11 |
from helper import (
|
12 |
build_agents,
|
|
|
27 |
|
28 |
|
29 |
def on_chat_submit():
|
30 |
+
with st.session_state.next_round.container():
|
31 |
+
with st.chat_message('user'):
|
32 |
+
st.write(st.session_state.chat_input)
|
33 |
+
with st.chat_message('assistant'):
|
34 |
+
container = st.container()
|
35 |
+
st_callback = ChatDataAgentCallBackHandler(container, collapse_completed_thoughts=False)
|
36 |
+
ret = st.session_state.agent({"input": st.session_state.chat_input}, callbacks=[st_callback])
|
37 |
+
print(ret)
|
38 |
|
39 |
|
40 |
def clear_history():
|
|
|
144 |
with st.sidebar:
|
145 |
with st.expander("Session Management"):
|
146 |
refresh_sessions()
|
147 |
+
st.info("Here you can set up your session! \n\nYou can **change your prompt** here!",
|
148 |
+
icon="π€")
|
149 |
+
st.info(("**Add columns by clicking the empty row**.\n"
|
150 |
+
"And **delete columns by selecting rows with a press on `DEL` Key**"),
|
151 |
+
icon="π‘")
|
152 |
+
st.info("Don't forget to **click `Submit Change` to save your change**!", icon="π")
|
153 |
st.data_editor(
|
154 |
st.session_state.current_sessions,
|
155 |
num_rows="dynamic",
|
|
|
158 |
)
|
159 |
st.button("Submit Change!", on_click=on_session_change_submit)
|
160 |
with st.expander("Session Selection", expanded=True):
|
161 |
+
st.info("Here you can select your session!", icon="π€")
|
162 |
+
st.info("If no session is attach to your account, then we will add a default session to you!", icon="β€οΈ")
|
163 |
try:
|
164 |
dfl_indx = [
|
165 |
x["session_id"] for x in st.session_state.current_sessions
|
|
|
168 |
print("*** ", str(e))
|
169 |
dfl_indx = 0
|
170 |
st.selectbox(
|
171 |
+
"Choose a session to chat:",
|
172 |
options=st.session_state.current_sessions,
|
173 |
index=dfl_indx,
|
174 |
key="sel_sess",
|
|
|
177 |
)
|
178 |
print(st.session_state.sel_sess)
|
179 |
with st.expander("Tool Settings", expanded=True):
|
180 |
+
st.info("Here you can select your tools.", icon="π§")
|
181 |
+
st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="π·ββοΈ")
|
182 |
st.multiselect(
|
183 |
"Knowledge Base",
|
184 |
st.session_state.tools.keys(),
|
185 |
+
default=["Wikipedia + Self Querying"],
|
186 |
key="selected_tools",
|
187 |
on_change=refresh_agent,
|
188 |
)
|
|
|
213 |
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
214 |
)
|
215 |
st.write(f"{msg.content}")
|
216 |
+
st.session_state["next_round"] = st.empty()
|
217 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
helper.py
CHANGED
@@ -2,12 +2,15 @@
|
|
2 |
import json
|
3 |
import time
|
4 |
import hashlib
|
5 |
-
from typing import Dict, Any
|
6 |
import re
|
7 |
import pandas as pd
|
8 |
from os import environ
|
9 |
import streamlit as st
|
10 |
import datetime
|
|
|
|
|
|
|
11 |
|
12 |
from sqlalchemy import Column, Text, create_engine, MetaData
|
13 |
from langchain.agents import AgentExecutor
|
@@ -28,7 +31,7 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
|
28 |
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
29 |
from langchain.prompts.prompt import PromptTemplate
|
30 |
from langchain.chat_models import ChatOpenAI
|
31 |
-
from langchain.schema import BaseRetriever
|
32 |
from langchain import OpenAI
|
33 |
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
34 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
@@ -36,12 +39,12 @@ from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
|
36 |
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
37 |
from langchain.vectorstores import MyScaleSettings
|
38 |
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
39 |
-
from langchain.schema import Document
|
40 |
from langchain.prompts.prompt import PromptTemplate
|
41 |
from langchain.prompts.chat import MessagesPlaceholder
|
42 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
43 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
44 |
-
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage
|
|
|
45 |
from langchain.memory import SQLChatMessageHistory
|
46 |
from langchain.memory.chat_message_histories.sql import \
|
47 |
BaseMessageConverter, DefaultMessageConverter
|
@@ -389,6 +392,26 @@ def create_message_model(table_name, DynamicBase): # type: ignore
|
|
389 |
|
390 |
return Message
|
391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
393 |
"""The default message converter for SQLChatMessageHistory."""
|
394 |
|
@@ -411,9 +434,10 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
|
411 |
"additional_kwargs": {"timestamp": tstamp},
|
412 |
"data": message.dict()})
|
413 |
)
|
|
|
414 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
415 |
msg_dump = json.loads(sql_message.message)
|
416 |
-
msg =
|
417 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
418 |
return msg
|
419 |
|
@@ -447,6 +471,38 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
|
|
447 |
**kwargs
|
448 |
)
|
449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
450 |
@st.cache_resource
|
451 |
def build_tools():
|
452 |
"""build all resources
|
@@ -465,13 +521,15 @@ def build_tools():
|
|
465 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
466 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
467 |
sel_map_obj.update({
|
468 |
-
f"
|
469 |
-
f"Vector SQL
|
470 |
})
|
471 |
return sel_map_obj
|
472 |
|
473 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
474 |
-
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
|
|
|
|
475 |
tools = [st.session_state.tools[k] for k in tool_names]
|
476 |
agent = create_agent_executor(
|
477 |
"chat_memory",
|
|
|
2 |
import json
|
3 |
import time
|
4 |
import hashlib
|
5 |
+
from typing import Dict, Any, List
|
6 |
import re
|
7 |
import pandas as pd
|
8 |
from os import environ
|
9 |
import streamlit as st
|
10 |
import datetime
|
11 |
+
from langchain.schema import BaseRetriever
|
12 |
+
from langchain.tools import Tool
|
13 |
+
from langchain.pydantic_v1 import BaseModel, Field
|
14 |
|
15 |
from sqlalchemy import Column, Text, create_engine, MetaData
|
16 |
from langchain.agents import AgentExecutor
|
|
|
31 |
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
32 |
from langchain.prompts.prompt import PromptTemplate
|
33 |
from langchain.chat_models import ChatOpenAI
|
34 |
+
from langchain.schema import BaseRetriever, Document
|
35 |
from langchain import OpenAI
|
36 |
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
37 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
|
|
39 |
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
40 |
from langchain.vectorstores import MyScaleSettings
|
41 |
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
|
|
42 |
from langchain.prompts.prompt import PromptTemplate
|
43 |
from langchain.prompts.chat import MessagesPlaceholder
|
44 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
45 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
46 |
+
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage,\
|
47 |
+
SystemMessage, ChatMessage, ToolMessage
|
48 |
from langchain.memory import SQLChatMessageHistory
|
49 |
from langchain.memory.chat_message_histories.sql import \
|
50 |
BaseMessageConverter, DefaultMessageConverter
|
|
|
392 |
|
393 |
return Message
|
394 |
|
395 |
+
def _message_from_dict(message: dict) -> BaseMessage:
|
396 |
+
_type = message["type"]
|
397 |
+
if _type == "human":
|
398 |
+
return HumanMessage(**message["data"])
|
399 |
+
elif _type == "ai":
|
400 |
+
return AIMessage(**message["data"])
|
401 |
+
elif _type == "system":
|
402 |
+
return SystemMessage(**message["data"])
|
403 |
+
elif _type == "chat":
|
404 |
+
return ChatMessage(**message["data"])
|
405 |
+
elif _type == "function":
|
406 |
+
return FunctionMessage(**message["data"])
|
407 |
+
elif _type == "tool":
|
408 |
+
return ToolMessage(**message["data"])
|
409 |
+
elif _type == "AIMessageChunk":
|
410 |
+
message["data"]["type"] = "ai"
|
411 |
+
return AIMessage(**message["data"])
|
412 |
+
else:
|
413 |
+
raise ValueError(f"Got unexpected message type: {_type}")
|
414 |
+
|
415 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
416 |
"""The default message converter for SQLChatMessageHistory."""
|
417 |
|
|
|
434 |
"additional_kwargs": {"timestamp": tstamp},
|
435 |
"data": message.dict()})
|
436 |
)
|
437 |
+
|
438 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
439 |
msg_dump = json.loads(sql_message.message)
|
440 |
+
msg = _message_from_dict(msg_dump)
|
441 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
442 |
return msg
|
443 |
|
|
|
471 |
**kwargs
|
472 |
)
|
473 |
|
474 |
+
class RetrieverInput(BaseModel):
|
475 |
+
query: str = Field(description="query to look up in retriever")
|
476 |
+
|
477 |
+
def create_retriever_tool(
|
478 |
+
retriever: BaseRetriever, name: str, description: str
|
479 |
+
) -> Tool:
|
480 |
+
"""Create a tool to do retrieval of documents.
|
481 |
+
|
482 |
+
Args:
|
483 |
+
retriever: The retriever to use for the retrieval
|
484 |
+
name: The name for the tool. This will be passed to the language model,
|
485 |
+
so should be unique and somewhat descriptive.
|
486 |
+
description: The description for the tool. This will be passed to the language
|
487 |
+
model, so should be descriptive.
|
488 |
+
|
489 |
+
Returns:
|
490 |
+
Tool class to pass to an agent
|
491 |
+
"""
|
492 |
+
def wrap(func):
|
493 |
+
def wrapped_retrieve(*args, **kwargs):
|
494 |
+
docs: List[Document] = func(*args, **kwargs)
|
495 |
+
return json.dumps([d.dict() for d in docs])
|
496 |
+
return wrapped_retrieve
|
497 |
+
|
498 |
+
return Tool(
|
499 |
+
name=name,
|
500 |
+
description=description,
|
501 |
+
func=wrap(retriever.get_relevant_documents),
|
502 |
+
coroutine=retriever.aget_relevant_documents,
|
503 |
+
args_schema=RetrieverInput,
|
504 |
+
)
|
505 |
+
|
506 |
@st.cache_resource
|
507 |
def build_tools():
|
508 |
"""build all resources
|
|
|
521 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
522 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
523 |
sel_map_obj.update({
|
524 |
+
f"{k} + Self Querying": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
|
525 |
+
f"{k} + Vector SQL": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
|
526 |
})
|
527 |
return sel_map_obj
|
528 |
|
529 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
530 |
+
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
531 |
+
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
532 |
+
)
|
533 |
tools = [st.session_state.tools[k] for k in tool_names]
|
534 |
agent = create_agent_executor(
|
535 |
"chat_memory",
|