# ChatData/backend/chat_bot/session_manager.py
# lqhl's picture
# Synced repo using 'sync_with_huggingface' Github Action
# e931b70 verified
import json
from backend.chat_bot.tools import create_session_table, create_message_history_table
from backend.constants.variables import GLOBAL_CONFIG
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
from sqlalchemy import orm, create_engine
from logger import logger
def get_sessions(engine, model_class, user_id):
    """Return the sessions stored for ``user_id``, newest first.

    Args:
        engine: SQLAlchemy engine used to open a short-lived ORM session.
        model_class: ORM model of the session table. Assumed to expose
            ``user_id``, ``session_id``, ``system_prompt`` and ``create_by``
            columns, like the model used by ``SessionManager`` -- TODO confirm.
        user_id: Identifier of the user whose sessions are listed.

    Returns:
        A list of ``{"session_id": ..., "system_prompt": ...}`` dicts ordered
        by ``create_by`` descending. ``session_id`` is reduced to the part
        after the last ``"?"``, matching ``SessionManager.list_sessions``.
    """
    with orm.sessionmaker(engine)() as session:
        # The previous version compared ``session_id`` with ``user_id`` and
        # then passed the un-executed Query object to ``json.loads``, which
        # only accepts str/bytes and therefore always raised ``TypeError``.
        rows = (
            session.query(model_class)
            .where(model_class.user_id == user_id)
            .order_by(model_class.create_by.desc())
            .all()  # run the query while the session is still open
        )
        return [
            {
                "session_id": row.session_id.split("?")[-1],
                "system_prompt": row.system_prompt,
            }
            for row in rows
        ]
class SessionManager:
    """Persist chat sessions and their message history through SQLAlchemy.

    On construction the manager connects to a ClickHouse-protocol database,
    creates the session and message-history tables if they do not exist, and
    keeps a session factory that every method uses for short-lived sessions.
    """

    def __init__(
            self,
            session_state,
            host,
            port,
            username,
            password,
            db: str = 'chat',
            session_table: str = 'sessions',
            msg_table: str = 'chat_memory'
    ) -> None:
        """Connect to the database and make sure both tables exist.

        Args:
            session_state: Caller-owned state object; only stored on the
                instance, never read here.
            host: Database host name.
            port: Database port.
            username: Database user.
            password: Database password.
            db: Database (schema) name.
            session_table: Name of the table holding one row per session.
            msg_table: Name of the table holding the chat messages.
        """
        # Deliberately ``== False`` rather than ``not ...``: an unset/None
        # flag must keep selecting HTTPS, exactly as before.
        if GLOBAL_CONFIG.myscale_enable_https == False:  # noqa: E712
            protocol = 'http'
        else:
            protocol = 'https'
        # NOTE(review): credentials are interpolated verbatim; a password
        # containing characters such as '@' or '/' would need URL-encoding --
        # confirm what callers pass before changing this.
        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol={protocol}'
        self.engine = create_engine(conn_str, echo=False)

        # Each model gets its own declarative base, so ``create_all`` on one
        # model's metadata only touches that model's table.
        self.session_model_class = create_session_table(
            session_table, declarative_base())
        self.session_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_history_table(
            msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)

        self.session_orm = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id: str):
        """Return the sessions of ``user_id``, newest first.

        Returns:
            A list of ``{"session_id": ..., "system_prompt": ...}`` dicts.
            ``session_id`` is the part of the stored id after the last ``"?"``.
        """
        model = self.session_model_class
        with self.session_orm() as session:
            # ``.all()`` executes the query while the session is still open,
            # instead of iterating a lazy Query after the session was closed.
            rows = (
                session.query(model)
                .where(model.user_id == user_id)
                .order_by(model.create_by.desc())
                .all()
            )
            return [
                {
                    "session_id": row.session_id.split("?")[-1],
                    "system_prompt": row.system_prompt,
                }
                for row in rows
            ]

    def modify_system_prompt(self, session_id, sys_prompt):
        """Set the system prompt of the session identified by ``session_id``.

        Logs a warning and changes nothing when no such session exists.
        """
        model = self.session_model_class
        with self.session_orm() as session:
            obj = session.query(model).where(
                model.session_id == session_id).first()
            if obj:
                obj.system_prompt = sys_prompt
                session.commit()
            else:
                logger.warning(f"Session {session_id} not found")

    def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
        """Insert a new session row.

        Any extra keyword arguments are serialised to JSON and stored in the
        row's ``additionals`` field; ``create_by`` is set to the current
        (naive, local) time.
        """
        with self.session_orm() as session:
            elem = self.session_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    def remove_session(self, session_id: str):
        """Delete a session and every chat message recorded for it."""
        with self.session_orm() as session:
            # Remove the session row itself.
            session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).delete()
            # Remove the related chat history.
            session.query(self.msg_model_class).where(
                self.msg_model_class.session_id == session_id).delete()
            # Leaving the ``with`` block only closes the session; without an
            # explicit commit a transactional backend would roll both deletes
            # back. This also matches add_session / modify_system_prompt.
            session.commit()