import json
import logging

import streamlit as st
import torch
from lagent.actions import ActionExecutor
from lagent.agents.internlm2_agent import Internlm2Protocol
from lagent.schema import ActionReturn, AgentReturn
from lmdeploy import GenerationConfig

from utils.digital_human.digital_human_worker import gen_digital_human_video_in_spinner
from utils.rag.rag_worker import build_rag_prompt
from utils.tts.tts_worker import gen_tts_in_spinner

# The agent handlers are built by init_handlers(), which is currently disabled
# (see the commented-out block below). Default them to None so that
# get_agent_result() degrades gracefully instead of raising NameError.
protocol_handler = None
action_executor = None

def prepare_generation_config(skip_special_tokens=True):
    """Build the sampling config shared by the chat and agent paths.

    The agent path passes skip_special_tokens=False so that the protocol
    tokens (<|action_start|> / <|action_end|>) survive for parsing.
    """
    gen_config = GenerationConfig(
        top_p=0.8,
        temperature=0.7,
        repetition_penalty=1.005,
        skip_special_tokens=skip_special_tokens,
    )
    return gen_config
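
# Sketch of how this config is consumed (the model id below is illustrative,
# not taken from this file; `stream_infer` is the lmdeploy pipeline API used
# throughout this module):
#
#   from lmdeploy import pipeline
#   model_pipe = pipeline("internlm/internlm2-chat-7b")
#   for item in model_pipe.stream_infer(messages, gen_config=prepare_generation_config()):
#       print(item.text, end="", flush=True)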


def combine_history(prompt, meta_instruction, history_msg=None, first_input_str=""):
    """Assemble the full message list: system prompt, optional first user
    turn, prior history, then the current user prompt.

    Returns a batch of one conversation, as model_pipe.stream_infer expects.
    """
    total_prompt = [{"role": "system", "content": meta_instruction}]

    if first_input_str != "":
        total_prompt.append({"role": "user", "content": first_input_str})

    if history_msg is not None:
        for message in history_msg:
            total_prompt.append({"role": message["role"], "content": message["content"]})

    total_prompt.append({"role": "user", "content": prompt})
    return [total_prompt]
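
# For reference, the nested shape this returns (derived directly from the
# code above):
#
#   >>> combine_history("How soon does it ship?", "You are a shop assistant.")
#   [[{'role': 'system', 'content': 'You are a shop assistant.'},
#     {'role': 'user', 'content': 'How soon does it ship?'}]]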


'''
@st.cache_resource
# def init_handlers(departure_place, delivery_company_name):
def init_handlers():
    # from utils.agent.delivery_time_query import DeliveryTimeQueryAction  # isort:skip

    # "When tools or the code interpreter are enabled, choose the appropriate
    # tool to call based on the request."
    META_CN = "当开启工具以及代码时,根据需求选择合适的工具进行调用"

    # System prompt describing the stateful Jupyter interpreter tool and the
    # scenarios it suits (data analysis, computation, programming examples,
    # text processing, ML demos, CSV/JSON file handling).
    INTERPRETER_CN = (
        "你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。"
        "当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。"
        "这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),"
        "复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),"
        "文本处理和分析(比如文本解析和自然语言处理),"
        "机器学习和数据科学(用于展示模型训练和数据可视化),"
        "以及文件操作和数据导入(处理CSV、JSON等格式的文件)。"
    )

    # Plugin prompt: lists the available tools and tells the model to answer
    # directly once it has enough information, avoiding unnecessary or
    # fabricated tool calls.
    PLUGIN_CN = (
        "你可以使用如下工具:"
        "\n{prompt}\n"
        "如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! "
        "同时注意你可以使用的工具,不要随意捏造!"
    )

    protocol_handler = Internlm2Protocol(
        meta_prompt=META_CN,
        interpreter_prompt=INTERPRETER_CN,
        plugin_prompt=PLUGIN_CN,
        tool=dict(
            begin="{start_token}{name}\n",
            start_token="<|action_start|>",
            name_map=dict(plugin="<|plugin|>", interpreter="<|interpreter|>"),
            belong="assistant",
            end="<|action_end|>\n",
        ),
    )

    action_list = [
        DeliveryTimeQueryAction(
            departure_place=departure_place,
            delivery_company_name=delivery_company_name,
        ),
    ]

    # plugin_map = {action.name: action for action in action_list}
    # plugin_name = [action.name for action in action_list]
    # plugin_action = [plugin_map[name] for name in plugin_name]
    # action_executor = ActionExecutor(actions=plugin_action)

    # return action_executor, protocol_handler
'''
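
# To restore the agent path, re-enable the block above (including its
# DeliveryTimeQueryAction import and the commented-out return statement),
# then build the handlers at import time, roughly:
#
#   action_executor, protocol_handler = init_handlers()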


def get_agent_result(model_pipe, prompt_input):
    """Run up to `max_turn` tool-calling rounds with the Internlm2 protocol.

    Each round streams a model response, parses it for a plugin/interpreter
    action, and executes the requested plugin. Returns the first action
    result's content, or "" when no tool was called (or none is available).
    """
    if protocol_handler is None or action_executor is None:
        # init_handlers() is currently disabled, so there is no protocol to
        # parse tool calls with; fall back to the plain (non-agent) answer.
        return ""

    inner_history = [{"role": "user", "content": prompt_input}]
    interpreter_executor = None
    max_turn = 7
    for _ in range(max_turn):
        prompt = protocol_handler.format(
            inner_step=inner_history,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        cur_response = ""

        agent_return = AgentReturn()
        for item in model_pipe.stream_infer(prompt, gen_config=prepare_generation_config(skip_special_tokens=False)):
            # Normalize "~" to a Chinese full stop and collapse duplicates.
            if "~" in item.text:
                item.text = item.text.replace("~", "。").replace("。。", "。")

            cur_response += item.text

        name, language, action = protocol_handler.parse(
            message=cur_response,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        if name:
            if name == "plugin":
                if action_executor:
                    executor = action_executor
                else:
                    logging.info(msg="No plugin is instantiated!")
                    continue
                try:
                    action = json.loads(action)
                except Exception as e:
                    logging.info(msg=f"Invalid action {e}")
                    continue
            elif name == "interpreter":
                if interpreter_executor:
                    executor = interpreter_executor
                else:
                    logging.info(msg="No interpreter is instantiated!")
                    continue
        agent_return.response = action

        print(f"Agent response: {cur_response}")

        if name:
            print(f"Agent action: {action}")
            action_return: ActionReturn = executor(action["name"], action["parameters"])

            try:
                return_str = action_return.result[0]["content"]
                return return_str
            except Exception as e:
                logging.info(msg=f"Cannot read action result: {e}")
                return ""

        if not name:
            agent_return.response = language
            break

    return ""


def get_turbomind_response(
    prompt,
    meta_instruction,
    user_avator,
    robot_avator,
    model_pipe,
    session_messages,
    add_session_msg=True,
    first_input_str="",
    rag_retriever=None,
    product_name="",
    enable_agent=True,
):
    """Answer `prompt` in the Streamlit chat: try the agent path first, fall
    back to RAG, then stream the model reply and generate TTS audio plus the
    digital-human video from the final text.
    """
    prompt_pro = ""  # the enriched prompt; "" means "use the raw prompt"

    agent_response = ""
    if enable_agent:
        # "Here is information retrieved online: '{}'. The customer's
        # question: '{}'. Read the information carefully and answer in
        # character."
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"
        )

        agent_response = get_agent_result(model_pipe, prompt)
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, prompt)
            print(f"Agent response: {agent_response}")
            prompt_pro = agent_response

    # Fall back to RAG only when the agent produced nothing.
    if rag_retriever is not None and prompt_pro == "":
        prompt_pro = build_rag_prompt(rag_retriever, product_name, prompt)

    real_prompt = combine_history(
        prompt_pro if prompt_pro != "" else prompt,
        meta_instruction,
        history_msg=session_messages,
        first_input_str=first_input_str,
    )
    print(real_prompt)

    # Record the user's turn in the session history (with their avatar).
    if add_session_msg:
        session_messages.append({"role": "user", "content": prompt, "avatar": user_avator})

    with st.chat_message("assistant", avatar=robot_avator):
        message_placeholder = st.empty()
        cur_response = ""
        for item in model_pipe.stream_infer(real_prompt, gen_config=prepare_generation_config()):
            # Normalize "~" to a Chinese full stop and collapse duplicates so
            # the TTS stage reads the text cleanly.
            if "~" in item.text:
                item.text = item.text.replace("~", "。").replace("。。", "。")

            cur_response += item.text
            message_placeholder.markdown(cur_response + "▌")
        message_placeholder.markdown(cur_response)

        # Speech and the digital-human video are generated from the final answer.
        tts_save_path = gen_tts_in_spinner(cur_response)
        gen_digital_human_video_in_spinner(tts_save_path)

    session_messages.append(
        {
            "role": "assistant",
            "content": cur_response,
            "avatar": robot_avator,
            "wav": tts_save_path,
        }
    )
    torch.cuda.empty_cache()
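
# Minimal sketch of how a Streamlit page might drive this function. All the
# page-level names here (avatars, retriever, product_name, model loading)
# are assumptions, not defined in this module:
#
#   if user_input := st.chat_input("Ask about the product"):
#       with st.chat_message("user", avatar=user_avator):
#           st.markdown(user_input)
#       get_turbomind_response(
#           user_input,
#           meta_instruction,
#           user_avator,
#           robot_avator,
#           model_pipe,
#           st.session_state.messages,
#           rag_retriever=rag_retriever,
#           product_name=product_name,
#       )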