Spaces:
Running
Running
File size: 3,082 Bytes
e931b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import os
from typing import Sequence, List
from urllib.parse import quote

import streamlit as st
from langchain.agents import AgentExecutor
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import MessagesPlaceholder
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.memory import SQLChatMessageHistory

from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
from logger import logger
def _myscale_connection_string(enable_https) -> str:
    """Build the SQLAlchemy URL for the MyScale ``chat`` database.

    Credentials and address come from the MYSCALE_USER, MYSCALE_PASSWORD,
    MYSCALE_HOST and MYSCALE_PORT environment variables (``KeyError`` if any
    is unset). ``enable_https`` selects the ``protocol`` query parameter:
    any truthy value gives ``https``, anything falsy gives ``http``.
    """
    # Percent-encode the credentials so characters such as '@', ':', '/', '#'
    # or '%' in a password cannot be misread as URL delimiters. For plain
    # alphanumeric credentials this leaves the string unchanged.
    user = quote(os.environ["MYSCALE_USER"], safe="")
    password = quote(os.environ["MYSCALE_PASSWORD"], safe="")
    host = os.environ["MYSCALE_HOST"]
    port = os.environ["MYSCALE_PORT"]
    protocol = "https" if enable_https else "http"
    return f"clickhouse://{user}:{password}@{host}:{port}/chat?protocol={protocol}"


def create_agent_executor(
    agent_name: str,
    session_id: str,
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    system_prompt: str,
    **kwargs
) -> AgentExecutor:
    """Create an OpenAI-functions agent executor with SQL-backed chat memory.

    Args:
        agent_name: Name handed (with spaces replaced by underscores) to
            ``DefaultClickhouseMessageConverter``; presumably it tags/partitions
            stored messages per agent — confirm against the converter.
        session_id: Chat session whose history is loaded from and saved to
            the MyScale ``chat`` database via ``SQLChatMessageHistory``.
        llm: Language model driving both the agent and the token-buffer memory.
        tools: Tools the agent may call.
        system_prompt: Content of the agent's system message.
        **kwargs: Extra ``AgentExecutor`` arguments. ``verbose`` and
            ``return_intermediate_steps`` default to ``True`` and may be
            overridden here.

    Returns:
        The configured ``AgentExecutor``.

    Raises:
        KeyError: If a required MYSCALE_* environment variable is missing.
    """
    agent_name = agent_name.replace(" ", "_")
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=_myscale_connection_string(GLOBAL_CONFIG.myscale_enable_https),
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name),
    )
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
    # The "history" placeholder is where the memory's stored messages are
    # injected into the prompt.
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    # Apply these as defaults rather than fixed keywords. Passing them
    # explicitly used to raise "got multiple values for keyword argument".
    kwargs.setdefault("verbose", True)
    kwargs.setdefault("return_intermediate_steps", True)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        **kwargs
    )
def build_agents(
    session_id: str,
    tool_names: List[str],
    model: str = "gpt-3.5-turbo-0125",
    temperature: float = 0.6,
    system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> AgentExecutor:
    """Build a streaming chat agent using the named retrieval tools.

    Tools are looked up in Streamlit session state under
    ``AVAILABLE_RETRIEVAL_TOOLS``, falling back to ``RETRIEVER_TOOLS`` when
    that key is absent.

    Args:
        session_id: Chat session used for the agent's persisted memory.
        tool_names: Keys of the tools to hand to the agent, in order.
        model: OpenAI chat model name.
        temperature: Sampling temperature for the model.
        system_prompt: System message for the agent.

    Returns:
        The ``AgentExecutor`` produced by ``create_agent_executor``.

    Raises:
        KeyError: If any name in ``tool_names`` is not registered in session
            state (including when no tool registry is present at all).
    """
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    if tools is None:
        # No registry in session state. An empty tool_names list still
        # succeeds; any requested name is reported as missing below instead
        # of failing with "'NoneType' object is not subscriptable".
        tools = {}
    missing = [name for name in tool_names if name not in tools]
    if missing:
        raise KeyError(f"unknown retrieval tool(s) {missing}; available: {list(tools)}")
    selected_tools = [tools[name] for name in tool_names]
    logger.info(f"create agent, use tools: {selected_tools}")
    return create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )
|