Spaces:
Running
Running
import json | |
from backend.chat_bot.tools import create_session_table, create_message_history_table | |
from backend.constants.variables import GLOBAL_CONFIG | |
try: | |
from sqlalchemy.orm import declarative_base | |
except ImportError: | |
from sqlalchemy.ext.declarative import declarative_base | |
from datetime import datetime | |
from sqlalchemy import orm, create_engine | |
from logger import logger | |
def get_sessions(engine, model_class, user_id):
    """Return the rows of *model_class* owned by *user_id*, newest first.

    Args:
        engine: SQLAlchemy engine to open a short-lived ORM session on.
        model_class: Mapped class to query. It must expose ``user_id`` and
            ``create_by`` columns; presumably the class built by
            ``create_session_table`` (TODO confirm against callers).
        user_id: Owner whose rows are returned.

    Returns:
        A list of ``model_class`` instances ordered by ``create_by``
        descending. The list is fully loaded before the session closes, so
        the returned (detached) objects keep their column values readable.
    """
    with orm.sessionmaker(engine)() as session:
        query = (
            session.query(model_class)
            # Filter on the owner column: stored session ids are not plain
            # user ids, so comparing ``session_id`` to ``user_id`` never matched.
            .where(model_class.user_id == user_id)
            .order_by(model_class.create_by.desc())
        )
        # A Query is not JSON text; execute it here while the session is open.
        return query.all()
class SessionManager:
    """Store chat sessions and their message history in a ClickHouse/MyScale database.

    Construction connects to the database and creates the session table and
    the message-history table if they do not exist yet.
    """

    def __init__(
        self,
        session_state,
        host,
        port,
        username,
        password,
        db='chat',
        session_table='sessions',
        msg_table='chat_memory'
    ) -> None:
        """Connect to the database and ensure both backing tables exist.

        Args:
            session_state: Opaque object stored as ``self.session_state``
                (not read by any method of this class).
            host: Database host name.
            port: Database port.
            username: Database user. It is percent-encoded before being
                placed in the URL.
            password: Database password. It is percent-encoded before being
                placed in the URL.
            db: Database name.
            session_table: Name of the table holding one row per session.
            msg_table: Name of the table holding chat messages.
        """
        from urllib.parse import quote

        # Only a value equal to False selects plain HTTP; anything else
        # (including None) selects HTTPS, exactly as before.
        protocol = 'http' if GLOBAL_CONFIG.myscale_enable_https == False else 'https'  # noqa: E712
        # Percent-encode credentials so characters such as '@', ':', '/' or
        # '#' cannot corrupt the URL. SQLAlchemy unquotes them when parsing.
        user_part = quote(str(username), safe="")
        password_part = quote(str(password), safe="")
        conn_str = (
            f'clickhouse://{user_part}:{password_part}'
            f'@{host}:{port}/{db}?protocol={protocol}'
        )
        self.engine = create_engine(conn_str, echo=False)
        # Each table is built on its own declarative base, so each has its
        # own MetaData and its own create_all call.
        self.session_model_class = create_session_table(
            session_table, declarative_base())
        self.session_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_history_table(msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.session_orm = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id: str):
        """Return the sessions owned by *user_id*, most recently created first.

        Returns:
            A list of ``{"session_id": ..., "system_prompt": ...}`` dicts.
            ``session_id`` is the part of the stored id after its last ``"?"``,
            or the whole id if it contains no ``"?"``.
        """
        model = self.session_model_class
        with self.session_orm() as session:
            rows = (
                session.query(model)
                .where(model.user_id == user_id)
                .order_by(model.create_by.desc())
            )
            # Iterate inside the ``with`` so the query runs while the session is open.
            return [
                {
                    "session_id": row.session_id.split("?")[-1],
                    "system_prompt": row.system_prompt,
                }
                for row in rows
            ]

    def modify_system_prompt(self, session_id, sys_prompt):
        """Replace the system prompt of the session whose id is *session_id*.

        Logs a warning and changes nothing if no such session exists.
        """
        with self.session_orm() as session:
            obj = session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).first()
            if obj:
                obj.system_prompt = sys_prompt
                session.commit()
            else:
                logger.warning(f"Session {session_id} not found")

    def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
        """Insert a new session row and commit it.

        ``create_by`` is set to the current local time. Any extra keyword
        arguments are stored JSON-encoded in the ``additionals`` column.
        """
        with self.session_orm() as session:
            elem = self.session_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    def remove_session(self, session_id: str):
        """Delete the session row *session_id* and all of its chat messages."""
        with self.session_orm() as session:
            session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).delete()
            session.query(self.msg_model_class).where(
                self.msg_model_class.session_id == session_id).delete()
            # Commit explicitly, as the other mutating methods do. Otherwise
            # closing the session at the end of the ``with`` rolls back its
            # transaction, and both DELETEs are discarded on backends that
            # honour transactions.
            session.commit()