File size: 3,082 Bytes
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from typing import Sequence, List

import streamlit as st
from langchain.agents import AgentExecutor
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
from logger import logger

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import MessagesPlaceholder
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.memory import SQLChatMessageHistory


def create_agent_executor(
        agent_name: str,
        session_id: str,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        system_prompt: str,
        **kwargs
) -> AgentExecutor:
    agent_name = agent_name.replace(" ", "_")
    conn_str = f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=f'{conn_str}/chat?protocol=http' if GLOBAL_CONFIG.myscale_enable_https == False else f'{conn_str}/chat?protocol=https',
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name))
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
        **kwargs
    )


def build_agents(
        session_id: str,
        tool_names: List[str],
        model: str = "gpt-3.5-turbo-0125",
        temperature: float = 0.6,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> AgentExecutor:
    """Build a streaming chat agent for ``session_id`` with the selected tools.

    Tools are looked up by name in the registry stored in Streamlit session
    state under ``AVAILABLE_RETRIEVAL_TOOLS``. If that key is absent, the
    registry under ``RETRIEVER_TOOLS`` is used instead.

    Args:
        session_id: Chat session identifier forwarded to ``create_agent_executor``.
        tool_names: Registry keys of the tools to enable, in order.
        model: OpenAI chat model name.
        temperature: Sampling temperature for the chat model.
        system_prompt: System message placed at the top of the agent prompt.

    Returns:
        The ``AgentExecutor`` built by ``create_agent_executor``.

    Raises:
        KeyError: If any name in ``tool_names`` is not in the tool registry
            (including when no registry is present in session state).
    """
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    # NOTE(review): the registry is assumed to be a mapping of name -> tool.
    tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    if tools is None:
        # No registry yet: an empty tool selection is still valid; any
        # requested name is reported below instead of a None-subscript TypeError.
        tools = {}
    unknown = [name for name in tool_names if name not in tools]
    if unknown:
        raise KeyError(f"unknown retrieval tool(s) {unknown}; available: {list(tools)}")
    selected_tools = [tools[k] for k in tool_names]
    logger.info(f"create agent, use tools: {selected_tools}")
    return create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )