Spaces:
Running
Running
Fangrui Liu
commited on
Commit
Β·
042a946
1
Parent(s):
e1383d0
update session model
Browse files- app.py +10 -1
- chat.py +158 -14
- helper.py +23 -37
- lib/schemas.py +52 -0
- lib/sessions.py +68 -0
app.py
CHANGED
@@ -10,13 +10,22 @@ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
|
10 |
|
11 |
from chat import chat_page
|
12 |
from login import login, back_to_main
|
|
|
13 |
|
14 |
|
15 |
-
from helper import build_tools, build_agents, build_all, sel_map, display
|
16 |
|
17 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
18 |
|
19 |
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
st.header("ChatData")
|
21 |
|
22 |
if 'retriever' not in st.session_state:
|
|
|
10 |
|
11 |
from chat import chat_page
|
12 |
from login import login, back_to_main
|
13 |
+
from helper import build_tools, build_agents, build_all, sel_map, display
|
14 |
|
15 |
|
|
|
16 |
|
17 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
18 |
|
19 |
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
|
20 |
+
st.markdown(
|
21 |
+
f"""
|
22 |
+
<style>
|
23 |
+
.st-e4 {{
|
24 |
+
max-width: 500px
|
25 |
+
}}
|
26 |
+
</style>""",
|
27 |
+
unsafe_allow_html=True,
|
28 |
+
)
|
29 |
st.header("ChatData")
|
30 |
|
31 |
if 'retriever' not in st.session_state:
|
chat.py
CHANGED
@@ -1,20 +1,37 @@
|
|
1 |
import pandas as pd
|
2 |
from os import environ
|
|
|
3 |
import datetime
|
4 |
import streamlit as st
|
|
|
5 |
from langchain.schema import HumanMessage, FunctionMessage
|
6 |
|
7 |
-
from helper import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from login import back_to_main
|
9 |
|
10 |
-
environ[
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def on_chat_submit():
|
13 |
-
ret = st.session_state.
|
14 |
print(ret)
|
15 |
-
|
|
|
16 |
def clear_history():
|
17 |
-
st.session_state
|
|
|
18 |
|
19 |
|
20 |
def back_to_main():
|
@@ -25,29 +42,156 @@ def back_to_main():
|
|
25 |
if "jump_query_ask" in st.session_state:
|
26 |
del st.session_state.jump_query_ask
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def chat_page():
|
29 |
-
|
|
|
|
|
|
|
|
|
30 |
with st.sidebar:
|
31 |
-
st.
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
st.button("Clear Chat History", on_click=clear_history)
|
34 |
st.button("Logout", on_click=back_to_main)
|
35 |
-
|
|
|
|
|
|
|
36 |
speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
|
37 |
if isinstance(msg, FunctionMessage):
|
38 |
with st.chat_message("Knowledge Base", avatar="π"):
|
39 |
-
|
40 |
-
|
|
|
41 |
st.write("Retrieved from knowledge base:")
|
42 |
try:
|
43 |
-
st.dataframe(
|
|
|
|
|
44 |
except:
|
45 |
st.write(msg.content)
|
46 |
else:
|
47 |
if len(msg.content) > 0:
|
48 |
with st.chat_message(speaker):
|
49 |
print(type(msg), msg.dict())
|
50 |
-
st.write(
|
|
|
|
|
51 |
st.write(f"{msg.content}")
|
52 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
53 |
-
|
|
|
1 |
import pandas as pd
|
2 |
from os import environ
|
3 |
+
from time import sleep
|
4 |
import datetime
|
5 |
import streamlit as st
|
6 |
+
from lib.sessions import SessionManager
|
7 |
from langchain.schema import HumanMessage, FunctionMessage
|
8 |
|
9 |
+
from helper import (
|
10 |
+
build_agents,
|
11 |
+
MYSCALE_HOST,
|
12 |
+
MYSCALE_PASSWORD,
|
13 |
+
MYSCALE_PORT,
|
14 |
+
MYSCALE_USER,
|
15 |
+
DEFAULT_SYSTEM_PROMPT,
|
16 |
+
)
|
17 |
from login import back_to_main
|
18 |
|
19 |
+
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
|
20 |
+
|
21 |
+
TOOL_NAMES = {
|
22 |
+
"langchain_retriever_tool": "Self-querying retriever",
|
23 |
+
"vecsql_retriever_tool": "Vector SQL",
|
24 |
+
}
|
25 |
+
|
26 |
|
27 |
def on_chat_submit():
|
28 |
+
ret = st.session_state.agent({"input": st.session_state.chat_input})
|
29 |
print(ret)
|
30 |
+
|
31 |
+
|
32 |
def clear_history():
|
33 |
+
if "agent" in st.session_state:
|
34 |
+
st.session_state.agent.memory.clear()
|
35 |
|
36 |
|
37 |
def back_to_main():
|
|
|
42 |
if "jump_query_ask" in st.session_state:
|
43 |
del st.session_state.jump_query_ask
|
44 |
|
45 |
+
|
46 |
+
def on_session_change_submit():
|
47 |
+
if "session_manager" in st.session_state and "session_editor" in st.session_state:
|
48 |
+
print(st.session_state.session_editor)
|
49 |
+
try:
|
50 |
+
for elem in st.session_state.session_editor["added_rows"]:
|
51 |
+
if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
|
52 |
+
if elem["session_id"] != "" and "?" not in elem["session_id"]:
|
53 |
+
st.session_state.session_manager.add_session(
|
54 |
+
user_id=st.session_state.user_name,
|
55 |
+
session_id=f"{st.session_state.user_name}?{elem['session_id']}",
|
56 |
+
system_prompt=elem["system_prompt"],
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
raise KeyError(
|
60 |
+
"`session_id` should NOT be neither empty nor contain question marks."
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
raise KeyError(
|
64 |
+
"You should fill both `session_id` and `system_prompt` to add a column!"
|
65 |
+
)
|
66 |
+
for elem in st.session_state.session_editor["deleted_rows"]:
|
67 |
+
st.session_state.session_manager.remove_session(
|
68 |
+
session_id=f"{st.session_state.user_name}?{st.session_state.current_sessions[elem]['session_id']}",
|
69 |
+
)
|
70 |
+
refresh_sessions()
|
71 |
+
if len(st.session_state.session_editor["deleted_rows"]) > 0:
|
72 |
+
try:
|
73 |
+
dfl_indx = [
|
74 |
+
x["session_id"] for x in st.session_state.current_sessions
|
75 |
+
].index("default")
|
76 |
+
except ValueError:
|
77 |
+
dfl_indx = 0
|
78 |
+
st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
|
79 |
+
except Exception as e:
|
80 |
+
sleep(2)
|
81 |
+
st.error(f"{type(e)}: {str(e)}")
|
82 |
+
finally:
|
83 |
+
st.session_state.session_editor["added_rows"] = []
|
84 |
+
st.session_state.session_editor["deleted_rows"] = []
|
85 |
+
refresh_agent()
|
86 |
+
|
87 |
+
|
88 |
+
def build_session_manager():
|
89 |
+
return SessionManager(
|
90 |
+
host=MYSCALE_HOST,
|
91 |
+
port=MYSCALE_PORT,
|
92 |
+
username=MYSCALE_USER,
|
93 |
+
password=MYSCALE_PASSWORD,
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def refresh_sessions():
|
98 |
+
st.session_state[
|
99 |
+
"current_sessions"
|
100 |
+
] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
|
101 |
+
if type(st.session_state.current_sessions) is not dict and len(st.session_state.current_sessions) <= 0:
|
102 |
+
st.session_state.session_manager.add_session(
|
103 |
+
st.session_state.user_name,
|
104 |
+
f"{st.session_state.user_name}?default",
|
105 |
+
DEFAULT_SYSTEM_PROMPT,
|
106 |
+
)
|
107 |
+
st.session_state[
|
108 |
+
"current_sessions"
|
109 |
+
] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
|
110 |
+
|
111 |
+
|
112 |
+
def refresh_agent():
|
113 |
+
with st.spinner("Initializing session..."):
|
114 |
+
print(
|
115 |
+
f"??? Changed to ",
|
116 |
+
f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}",
|
117 |
+
)
|
118 |
+
st.session_state["agent"] = build_agents(
|
119 |
+
f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}",
|
120 |
+
["LangChain Self Query Retriever For Wikipedia"]
|
121 |
+
if "selected_tools" not in st.session_state
|
122 |
+
else st.session_state.selected_tools,
|
123 |
+
system_prompt=DEFAULT_SYSTEM_PROMPT
|
124 |
+
if "sel_sess" not in st.session_state
|
125 |
+
else st.session_state.sel_sess["system_prompt"],
|
126 |
+
)
|
127 |
+
st.session_state["session_manager"] = build_session_manager()
|
128 |
+
|
129 |
+
|
130 |
def chat_page():
|
131 |
+
if "sel_sess" not in st.session_state:
|
132 |
+
st.session_state["sel_sess"] = {
|
133 |
+
"session_id": "default",
|
134 |
+
"system_prompt": DEFAULT_SYSTEM_PROMPT,
|
135 |
+
}
|
136 |
with st.sidebar:
|
137 |
+
with st.expander("Session Management"):
|
138 |
+
refresh_sessions()
|
139 |
+
st.data_editor(
|
140 |
+
st.session_state.current_sessions,
|
141 |
+
num_rows="dynamic",
|
142 |
+
key="session_editor",
|
143 |
+
use_container_width=True,
|
144 |
+
)
|
145 |
+
st.button("Submit Change!", on_click=on_session_change_submit)
|
146 |
+
with st.expander("Session Selection", expanded=True):
|
147 |
+
try:
|
148 |
+
dfl_indx = [
|
149 |
+
x["session_id"] for x in st.session_state.current_sessions
|
150 |
+
].index("default")
|
151 |
+
except ValueError:
|
152 |
+
dfl_indx = 0
|
153 |
+
st.selectbox(
|
154 |
+
"Choose a session be chat:",
|
155 |
+
options=st.session_state.current_sessions,
|
156 |
+
index=dfl_indx,
|
157 |
+
key="sel_sess",
|
158 |
+
format_func=lambda x: x["session_id"],
|
159 |
+
on_change=refresh_agent,
|
160 |
+
)
|
161 |
+
print(st.session_state.sel_sess)
|
162 |
+
with st.expander("Tool Settings", expanded=True):
|
163 |
+
st.multiselect(
|
164 |
+
"Knowledge Base",
|
165 |
+
st.session_state.tools.keys(),
|
166 |
+
default=["LangChain Self Query Retriever For Wikipedia"],
|
167 |
+
key="selected_tools",
|
168 |
+
on_change=refresh_agent,
|
169 |
+
)
|
170 |
st.button("Clear Chat History", on_click=clear_history)
|
171 |
st.button("Logout", on_click=back_to_main)
|
172 |
+
if 'agent' not in st.session_state:
|
173 |
+
refresh_agent()
|
174 |
+
print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
|
175 |
+
for msg in st.session_state.agent.memory.chat_memory.messages:
|
176 |
speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
|
177 |
if isinstance(msg, FunctionMessage):
|
178 |
with st.chat_message("Knowledge Base", avatar="π"):
|
179 |
+
st.write(
|
180 |
+
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
181 |
+
)
|
182 |
st.write("Retrieved from knowledge base:")
|
183 |
try:
|
184 |
+
st.dataframe(
|
185 |
+
pd.DataFrame.from_records(map(dict, eval(msg.content)))
|
186 |
+
)
|
187 |
except:
|
188 |
st.write(msg.content)
|
189 |
else:
|
190 |
if len(msg.content) > 0:
|
191 |
with st.chat_message(speaker):
|
192 |
print(type(msg), msg.dict())
|
193 |
+
st.write(
|
194 |
+
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
195 |
+
)
|
196 |
st.write(f"{msg.content}")
|
197 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
|
helper.py
CHANGED
@@ -68,6 +68,12 @@ MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
|
68 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
69 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
70 |
(HumanMessagePromptTemplate, '{question}')])
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def hint_arxiv():
|
73 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
@@ -415,7 +421,7 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
|
415 |
return self.model_class
|
416 |
|
417 |
|
418 |
-
def create_agent_executor(name, session_id, llm, tools, **kwargs):
|
419 |
name = name.replace(" ", "_")
|
420 |
conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
|
421 |
chat_memory = SQLChatMessageHistory(
|
@@ -425,12 +431,7 @@ def create_agent_executor(name, session_id, llm, tools, **kwargs):
|
|
425 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
426 |
|
427 |
_system_message = SystemMessage(
|
428 |
-
content=
|
429 |
-
"Do your best to answer the questions. "
|
430 |
-
"Feel free to use any tools available to look up "
|
431 |
-
"relevant information. Please keep all details in query "
|
432 |
-
"when calling search functions."
|
433 |
-
)
|
434 |
)
|
435 |
prompt = OpenAIFunctionsAgent.create_prompt(
|
436 |
system_message=_system_message,
|
@@ -463,38 +464,23 @@ def build_tools():
|
|
463 |
st.session_state["sel_map_obj"][k] = {}
|
464 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
465 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
466 |
-
sel_map_obj
|
467 |
-
"
|
468 |
-
"
|
469 |
-
}
|
470 |
return sel_map_obj
|
471 |
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
tools = []
|
484 |
-
else:
|
485 |
-
tools = [st.session_state.tools[k][m]]
|
486 |
-
if k not in agents:
|
487 |
-
agents[k] = {}
|
488 |
-
agents[k][n] = create_agent_executor(
|
489 |
-
"chat_memory",
|
490 |
-
session_id,
|
491 |
-
chat_llm,
|
492 |
-
tools=tools,
|
493 |
-
)
|
494 |
-
cnt += 1/6
|
495 |
-
p.progress(cnt, f"Building with Knowledge Base {k} via Retriever {n}...")
|
496 |
-
p.empty()
|
497 |
-
return agents
|
498 |
|
499 |
|
500 |
def display(dataframe, columns_=None, index=None):
|
|
|
68 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
69 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
70 |
(HumanMessagePromptTemplate, '{question}')])
|
71 |
+
DEFAULT_SYSTEM_PROMPT = (
|
72 |
+
"Do your best to answer the questions. "
|
73 |
+
"Feel free to use any tools available to look up "
|
74 |
+
"relevant information. Please keep all details in query "
|
75 |
+
"when calling search functions."
|
76 |
+
)
|
77 |
|
78 |
def hint_arxiv():
|
79 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
|
|
421 |
return self.model_class
|
422 |
|
423 |
|
424 |
+
def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs):
|
425 |
name = name.replace(" ", "_")
|
426 |
conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
|
427 |
chat_memory = SQLChatMessageHistory(
|
|
|
431 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
432 |
|
433 |
_system_message = SystemMessage(
|
434 |
+
content=system_prompt
|
|
|
|
|
|
|
|
|
|
|
435 |
)
|
436 |
prompt = OpenAIFunctionsAgent.create_prompt(
|
437 |
system_message=_system_message,
|
|
|
464 |
st.session_state["sel_map_obj"][k] = {}
|
465 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
466 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
467 |
+
sel_map_obj.update({
|
468 |
+
f"LangChain Self Query Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
|
469 |
+
f"Vector SQL Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
|
470 |
+
})
|
471 |
return sel_map_obj
|
472 |
|
473 |
+
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
474 |
+
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
|
475 |
+
tools = [st.session_state.tools[k] for k in tool_names]
|
476 |
+
agent = create_agent_executor(
|
477 |
+
"chat_memory",
|
478 |
+
session_id,
|
479 |
+
chat_llm,
|
480 |
+
tools=tools,
|
481 |
+
system_prompt=system_prompt
|
482 |
+
)
|
483 |
+
return agent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
484 |
|
485 |
|
486 |
def display(dataframe, columns_=None, index=None):
|
lib/schemas.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sqlalchemy import Column, Text
|
2 |
+
from clickhouse_sqlalchemy import types, engines
|
3 |
+
|
4 |
+
|
5 |
+
def create_message_model(table_name, DynamicBase): # type: ignore
|
6 |
+
"""
|
7 |
+
Create a message model for a given table name.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
table_name: The name of the table to use.
|
11 |
+
DynamicBase: The base class to use for the model.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
The model class.
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
# Model decleared inside a function to have a dynamic table name
|
19 |
+
class Message(DynamicBase):
|
20 |
+
__tablename__ = table_name
|
21 |
+
id = Column(types.Float64)
|
22 |
+
session_id = Column(Text)
|
23 |
+
user_id = Column(Text)
|
24 |
+
msg_id = Column(Text, primary_key=True)
|
25 |
+
type = Column(Text)
|
26 |
+
addtionals = Column(Text)
|
27 |
+
message = Column(Text)
|
28 |
+
__table_args__ = (
|
29 |
+
engines.ReplacingMergeTree(
|
30 |
+
partition_by='session_id',
|
31 |
+
order_by=('id', 'msg_id')),
|
32 |
+
{'comment': 'Store Chat History'}
|
33 |
+
)
|
34 |
+
|
35 |
+
return Message
|
36 |
+
|
37 |
+
|
38 |
+
def create_session_table(table_name, DynamicBase): # type: ignore
|
39 |
+
# Model decleared inside a function to have a dynamic table name
|
40 |
+
class Session(DynamicBase):
|
41 |
+
__tablename__ = table_name
|
42 |
+
user_id = Column(Text)
|
43 |
+
session_id = Column(Text, primary_key=True)
|
44 |
+
system_prompt = Column(Text)
|
45 |
+
create_by = Column(types.DateTime)
|
46 |
+
additionals = Column(Text)
|
47 |
+
__table_args__ = (
|
48 |
+
engines.ReplacingMergeTree(
|
49 |
+
order_by=('session_id')),
|
50 |
+
{'comment': 'Store Session and Prompts'}
|
51 |
+
)
|
52 |
+
return Session
|
lib/sessions.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
try:
|
3 |
+
from sqlalchemy.orm import declarative_base
|
4 |
+
except ImportError:
|
5 |
+
from sqlalchemy.ext.declarative import declarative_base
|
6 |
+
from datetime import datetime
|
7 |
+
from sqlalchemy import Column, Text, orm, create_engine
|
8 |
+
from clickhouse_sqlalchemy import types, engines
|
9 |
+
from .schemas import create_message_model, create_session_table
|
10 |
+
|
11 |
+
def get_sessions(engine, model_class, user_id):
|
12 |
+
with orm.sessionmaker(engine)() as session:
|
13 |
+
result = (
|
14 |
+
session.query(model_class)
|
15 |
+
.where(
|
16 |
+
model_class.session_id == user_id
|
17 |
+
)
|
18 |
+
.order_by(model_class.create_by.desc())
|
19 |
+
)
|
20 |
+
return json.loads(result)
|
21 |
+
|
22 |
+
class SessionManager:
|
23 |
+
def __init__(self, host, port, username, password, db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
|
24 |
+
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
|
25 |
+
self.engine = create_engine(conn_str, echo=False)
|
26 |
+
self.sess_model_class = create_session_table(sess_table, declarative_base())
|
27 |
+
self.sess_model_class.metadata.create_all(self.engine)
|
28 |
+
self.msg_model_class = create_message_model(msg_table, declarative_base())
|
29 |
+
self.msg_model_class.metadata.create_all(self.engine)
|
30 |
+
self.Session = orm.sessionmaker(self.engine)
|
31 |
+
|
32 |
+
def list_sessions(self, user_id):
|
33 |
+
with self.Session() as session:
|
34 |
+
result = (
|
35 |
+
session.query(self.sess_model_class)
|
36 |
+
.where(
|
37 |
+
self.sess_model_class.user_id == user_id
|
38 |
+
)
|
39 |
+
.order_by(self.sess_model_class.create_by.desc())
|
40 |
+
)
|
41 |
+
sessions = []
|
42 |
+
for r in result:
|
43 |
+
sessions.append({
|
44 |
+
"session_id": r.session_id.split("?")[-1],
|
45 |
+
"system_prompt": r.system_prompt,
|
46 |
+
})
|
47 |
+
return sessions
|
48 |
+
|
49 |
+
def modify_system_prompt(self, session_id, sys_prompt):
|
50 |
+
with self.Session() as session:
|
51 |
+
session.update(self.sess_model_class).where(self.sess_model_class==session_id).value(system_prompt=sys_prompt)
|
52 |
+
session.commit()
|
53 |
+
|
54 |
+
def add_session(self, user_id, session_id, system_prompt, **kwargs):
|
55 |
+
with self.Session() as session:
|
56 |
+
elem = self.sess_model_class(
|
57 |
+
user_id=user_id, session_id=session_id, system_prompt=system_prompt,
|
58 |
+
create_by=datetime.now(), additionals=json.dumps(kwargs)
|
59 |
+
)
|
60 |
+
session.add(elem)
|
61 |
+
session.commit()
|
62 |
+
|
63 |
+
def remove_session(self, session_id):
|
64 |
+
with self.Session() as session:
|
65 |
+
session.query(self.sess_model_class).where(self.sess_model_class.session_id==session_id).delete()
|
66 |
+
session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
|
67 |
+
|
68 |
+
|