# FaYo model d8d694f  (stray header/commit residue — was not valid Python; kept as a comment)
import json
import logging
import streamlit as st
import torch
from lagent.actions import ActionExecutor
from lagent.agents.internlm2_agent import Internlm2Protocol
from lagent.schema import ActionReturn, AgentReturn
from lmdeploy import GenerationConfig
from utils.digital_human.digital_human_worker import gen_digital_human_video_in_spinner
from utils.rag.rag_worker import build_rag_prompt
from utils.tts.tts_worker import gen_tts_in_spinner
def prepare_generation_config(skip_special_tokens=True):
    """Build the sampling configuration used for LMDeploy inference.

    Args:
        skip_special_tokens (bool): whether special tokens are stripped from
            the decoded output; the agent path passes ``False`` so that tool
            markers such as ``<|action_start|>`` survive decoding.

    Returns:
        GenerationConfig: nucleus sampling at top_p=0.8, temperature 0.7 and
        a light repetition penalty of 1.005.
    """
    # top_k=40, min_new_tokens=200 were tried previously and left disabled.
    sampling = dict(
        top_p=0.8,
        temperature=0.7,
        repetition_penalty=1.005,
        skip_special_tokens=skip_special_tokens,
    )
    return GenerationConfig(**sampling)
def combine_history(prompt, meta_instruction, history_msg=None, first_input_str=""):
    """Assemble a single-element batch of chat messages for the model.

    Args:
        prompt (str): the current user query, appended last.
        meta_instruction (str): system prompt placed first.
        history_msg (list[dict] | None): prior turns; only their ``role`` and
            ``content`` keys are carried over.
        first_input_str (str): optional fixed opening user message, inserted
            right after the system prompt when non-empty.

    Returns:
        list[list[dict]]: a batch of size one containing the message list.
    """
    conversation = [{"role": "system", "content": meta_instruction}]
    if first_input_str != "":
        conversation.append({"role": "user", "content": first_input_str})
    if history_msg is not None:
        # Copy only the fields the model consumes; drop extras such as avatars.
        conversation.extend(
            {"role": turn["role"], "content": turn["content"]} for turn in history_msg
        )
    conversation.append({"role": "user", "content": prompt})
    return [conversation]
# NOTE(review): the block below is dead code, disabled by wrapping it in a
# module-level string literal. It defines init_handlers(), which would build
# the `protocol_handler` / `action_executor` objects that get_agent_result()
# still references as globals — restore this (or remove both) before relying
# on the agent path.
'''
@st.cache_resource
# def init_handlers(departure_place, delivery_company_name):
def init_handlers():
    # from utils.agent.delivery_time_query import DeliveryTimeQueryAction # isort:skip
    META_CN = "当开启工具以及代码时,根据需求选择合适的工具进行调用"
    INTERPRETER_CN = (
        "你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。"
        "当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。"
        "这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),"
        "复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),"
        "文本处理和分析(比如文本解析和自然语言处理),"
        "机器学习和数据科学(用于展示模型训练和数据可视化),"
        "以及文件操作和数据导入(处理CSV、JSON等格式的文件)。"
    )
    PLUGIN_CN = (
        "你可以使用如下工具:"
        "\n{prompt}\n"
        "如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! "
        "同时注意你可以使用的工具,不要随意捏造!"
    )
    protocol_handler = Internlm2Protocol(
        meta_prompt=META_CN,
        interpreter_prompt=INTERPRETER_CN,
        plugin_prompt=PLUGIN_CN,
        tool=dict(
            begin="{start_token}{name}\n",
            start_token="<|action_start|>",
            name_map=dict(plugin="<|plugin|>", interpreter="<|interpreter|>"),
            belong="assistant",
            end="<|action_end|>\n",
        ),
    )
    action_list = [
        DeliveryTimeQueryAction(
            departure_place=departure_place,
            delivery_company_name=delivery_company_name,
        ),
    ]
    # plugin_map = {action.name: action for action in action_list}
    # plugin_name = [action.name for action in action_list]
    # plugin_action = [plugin_map[name] for name in plugin_name]
    # action_executor = ActionExecutor(actions=plugin_action)
    # return action_executor, protocol_handler
'''
# def get_agent_result(model_pipe, prompt_input, departure_place, delivery_company_name):
def get_agent_result(model_pipe, prompt_input):
    """Run a tool-calling (agent) loop over the model and return the tool output.

    The model is prompted with the Internlm2 agent protocol; when it emits a
    plugin/interpreter action, the action is parsed, executed, and the first
    result's ``content`` string is returned. Returns ``""`` when no tool is
    triggered, when execution fails, or after ``max_turn`` rounds.

    NOTE(review): `protocol_handler` and `action_executor` are referenced as
    globals here but are only defined inside the disabled (string-quoted)
    `init_handlers` block above — as written this raises NameError at call
    time. Confirm the intended wiring before enabling `enable_agent`.

    Args:
        model_pipe: LMDeploy pipeline exposing ``stream_infer`` (presumably —
            TODO confirm against the caller).
        prompt_input (str): the raw user question.

    Returns:
        str: the tool result content, or ``""`` on any failure.
    """
    # action_executor, protocol_handler = init_handlers(departure_place, delivery_company_name)
    inner_history = [{"role": "user", "content": prompt_input}]  # NOTE TEST !!!
    interpreter_executor = None
    max_turn = 7  # upper bound on agent think/act rounds
    for _ in range(max_turn):
        prompt = protocol_handler.format(  # build the agent-formatted prompt
            inner_step=inner_history,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        cur_response = ""
        agent_return = AgentReturn()
        # Stream the model output; special tokens are kept (skip_special_tokens=False)
        # so the protocol parser can see the <|action_start|>/<|action_end|> markers.
        for item in model_pipe.stream_infer(prompt, gen_config=prepare_generation_config(skip_special_tokens=False)):
            if "~" in item.text:
                # Replace tildes with a full stop and collapse doubled stops.
                item.text = item.text.replace("~", "。").replace("。。", "。")
            cur_response += item.text
        # Split the response into (tool name, plain-language part, action payload).
        name, language, action = protocol_handler.parse(
            message=cur_response,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        if name:  # "plugin"
            if name == "plugin":
                if action_executor:
                    executor = action_executor
                else:
                    logging.info(msg="No plugin is instantiated!")
                    continue
                try:
                    # Plugin actions arrive as a JSON string: {"name": ..., "parameters": ...}
                    action = json.loads(action)
                except Exception as e:
                    logging.info(msg=f"Invaild action {e}")
                    continue
            elif name == "interpreter":
                if interpreter_executor:
                    executor = interpreter_executor
                else:
                    logging.info(msg="No interpreter is instantiated!")
                    continue
            agent_return.response = action
        print(f"Agent response: {cur_response}")
        if name:
            print(f"Agent action: {action}")
            # Execute the tool and return its first result's content string.
            action_return: ActionReturn = executor(action["name"], action["parameters"])
            try:
                return_str = action_return.result[0]["content"]
                return return_str
            except Exception as e:
                # Best-effort: any malformed tool result degrades to "no agent answer".
                return ""
        if not name:
            # No tool call — treat the plain-language output as final and stop looping.
            agent_return.response = language
            break
    return ""
def get_turbomind_response(
    prompt,
    meta_instruction,
    user_avator,
    robot_avator,
    model_pipe,
    session_messages,
    add_session_msg=True,
    first_input_str="",
    rag_retriever=None,
    product_name="",
    enable_agent=True,
    # departure_place=None,
    # delivery_company_name=None,
):
    """Answer one user turn in the Streamlit chat: agent → RAG → LLM → TTS → video.

    Pipeline: (1) optionally run the tool-calling agent; (2) if the agent
    produced nothing and a RAG retriever is available, build a RAG-augmented
    prompt; (3) stream the LLM answer into the chat UI; (4) synthesize TTS
    audio and a digital-human video from the full answer; (5) record both
    turns in ``session_messages``.

    Args:
        prompt (str): current user message.
        meta_instruction (str): system prompt for the model.
        user_avator / robot_avator: avatar identifiers stored in the history.
        model_pipe: LMDeploy pipeline exposing ``stream_infer``.
        session_messages (list[dict]): mutable chat history; appended in place.
        add_session_msg (bool): whether to append the user turn to history.
        first_input_str (str): optional fixed opening user message.
        rag_retriever: retriever used by ``build_rag_prompt`` when the agent
            path yields nothing; ``None`` disables RAG.
        product_name (str): product context forwarded to the RAG prompt builder.
        enable_agent (bool): toggles the tool-calling agent step.

    Returns:
        None. Side effects: Streamlit UI updates, TTS/video file generation,
        ``session_messages`` mutation, CUDA cache release.
    """
    # ====================== Agent ======================
    agent_response = ""
    if enable_agent:
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # prompt template wrapping tool output
        )
        # agent_response = get_agent_result(model_pipe, prompt, departure_place, delivery_company_name)
        agent_response = get_agent_result(model_pipe, prompt)
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, prompt)
            print(f"Agent response: {agent_response}")
    prompt_pro = agent_response
    # ====================== RAG ======================
    if rag_retriever is not None and prompt_pro == "":
        # Agent produced nothing — fall back to querying the RAG knowledge base.
        prompt_pro = build_rag_prompt(rag_retriever, product_name, prompt)
    # ====================== Prepend chat history ======================
    real_prompt = combine_history(
        prompt_pro if prompt_pro != "" else prompt,
        meta_instruction,
        history_msg=session_messages,
        first_input_str=first_input_str,
    )  # attaches prior conversation turns to the prompt
    print(real_prompt)
    # Add user message to chat history
    if add_session_msg:
        session_messages.append({"role": "user", "content": prompt, "avatar": user_avator})
    with st.chat_message("assistant", avatar=robot_avator):
        message_placeholder = st.empty()
        cur_response = ""
        # Stream tokens into the placeholder with a trailing cursor glyph.
        for item in model_pipe.stream_infer(real_prompt, gen_config=prepare_generation_config()):
            if "~" in item.text:
                # Replace tildes with a full stop and collapse doubled stops.
                item.text = item.text.replace("~", "。").replace("。。", "。")
            cur_response += item.text
            message_placeholder.markdown(cur_response + "▌")
        message_placeholder.markdown(cur_response)
        tts_save_path = gen_tts_in_spinner(cur_response)  # synthesize the full sentence at once
        gen_digital_human_video_in_spinner(tts_save_path)
        # Add robot response to chat history
        session_messages.append(
            {
                "role": "assistant",
                "content": cur_response,  # pylint: disable=undefined-loop-variable
                "avatar": robot_avator,
                "wav": tts_save_path,
            }
        )
    torch.cuda.empty_cache()