Spaces:
Runtime error
Runtime error
| import json | |
| from backend.chat_bot.tools import create_session_table, create_message_history_table | |
| from backend.constants.variables import GLOBAL_CONFIG | |
| try: | |
| from sqlalchemy.orm import declarative_base | |
| except ImportError: | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from datetime import datetime | |
| from sqlalchemy import orm, create_engine | |
| from logger import logger | |
| def get_sessions(engine, model_class, user_id): | |
| with orm.sessionmaker(engine)() as session: | |
| result = ( | |
| session.query(model_class) | |
| .where( | |
| model_class.session_id == user_id | |
| ) | |
| .order_by(model_class.create_by.desc()) | |
| ) | |
| return json.loads(result) | |
| class SessionManager: | |
| def __init__( | |
| self, | |
| session_state, | |
| host, | |
| port, | |
| username, | |
| password, | |
| db='chat', | |
| session_table='sessions', | |
| msg_table='chat_memory' | |
| ) -> None: | |
| if GLOBAL_CONFIG.myscale_enable_https == False: | |
| conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=http' | |
| else: | |
| conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https' | |
| self.engine = create_engine(conn_str, echo=False) | |
| self.session_model_class = create_session_table( | |
| session_table, declarative_base()) | |
| self.session_model_class.metadata.create_all(self.engine) | |
| self.msg_model_class = create_message_history_table(msg_table, declarative_base()) | |
| self.msg_model_class.metadata.create_all(self.engine) | |
| self.session_orm = orm.sessionmaker(self.engine) | |
| self.session_state = session_state | |
| def list_sessions(self, user_id: str): | |
| with self.session_orm() as session: | |
| result = ( | |
| session.query(self.session_model_class) | |
| .where( | |
| self.session_model_class.user_id == user_id | |
| ) | |
| .order_by(self.session_model_class.create_by.desc()) | |
| ) | |
| sessions = [] | |
| for r in result: | |
| sessions.append({ | |
| "session_id": r.session_id.split("?")[-1], | |
| "system_prompt": r.system_prompt, | |
| }) | |
| return sessions | |
| # Update sys_prompt with given session_id | |
| def modify_system_prompt(self, session_id, sys_prompt): | |
| with self.session_orm() as session: | |
| obj = session.query(self.session_model_class).where( | |
| self.session_model_class.session_id == session_id).first() | |
| if obj: | |
| obj.system_prompt = sys_prompt | |
| session.commit() | |
| else: | |
| logger.warning(f"Session {session_id} not found") | |
| # Add a session(session_id, sys_prompt) | |
| def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs): | |
| with self.session_orm() as session: | |
| elem = self.session_model_class( | |
| user_id=user_id, session_id=session_id, system_prompt=system_prompt, | |
| create_by=datetime.now(), additionals=json.dumps(kwargs) | |
| ) | |
| session.add(elem) | |
| session.commit() | |
| # Remove a session and related chat history. | |
| def remove_session(self, session_id: str): | |
| with self.session_orm() as session: | |
| # remove session | |
| session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete() | |
| # remove related chat history. | |
| session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete() | |