# Author: limafang | commit b40a4c8 "上传utils" ("upload utils") | 8.96 kB
import base64
import hmac
import json
from datetime import datetime, timezone
from email.utils import formatdate
from urllib.parse import urlencode, urlparse

from websocket import create_connection, WebSocketConnectionClosedException

from utils.tools import get_prompt, process_response, init_script, create_script
class SparkAPI:
    """Client for the iFlytek Spark (SparkDesk) chat WebSocket API.

    Every request opens its own WebSocket connection, authenticated through
    signed query parameters (see ``__get_authorization_url``), sends one JSON
    request and reads reply frames until the answer is marked complete.
    """

    # Default chat endpoint; an instance may override it via ``api_url``.
    __api_url = 'wss://spark-api.xf-yun.com/v1.1/chat'
    # Ceiling applied to the ``max_tokens`` of every request.
    __max_token = 4096

    def __init__(self, app_id, api_key, api_secret, api_url=None):
        """Store the credentials obtained from the iFlytek console.

        Args:
            app_id: Application id sent in the ``header`` of every request.
            api_key: API key embedded in the ``authorization`` parameter.
            api_secret: Secret used to HMAC-sign the request.
            api_url: Optional WebSocket endpoint replacing the class default
                (for example another model version). The signature is
                computed from this URL's host and path.
        """
        self.__app_id = app_id
        self.__api_key = api_key
        self.__api_secret = api_secret
        if api_url is not None:
            # The instance attribute shadows the class-level default.
            self.__api_url = api_url

    def __set_max_tokens(self, token):
        """Change the per-request ``max_tokens`` ceiling.

        Anything that is not a positive ``int`` (including ``bool`` and 0)
        is reported on stdout and ignored.
        """
        if isinstance(token, bool) or not isinstance(token, int) or token <= 0:
            print("set_max_tokens() error: tokens should be a positive integer!")
            return
        self.__max_token = token

    def __get_authorization_url(self):
        """Return the endpoint URL carrying signed authorization parameters.

        Steps:
          1. ``date`` is the current time as an RFC 1123 date in GMT.
          2. The lines ``host: <host>``, ``date: <date>`` and
             ``GET <path> HTTP/1.1`` are joined with newlines, signed with
             HMAC-SHA256 using the API secret, and base64-encoded.
          3. The API key, algorithm name, signed header list and signature
             are formatted into one string, base64-encoded, and sent as the
             ``authorization`` query parameter together with ``date`` and
             ``host``.
        """
        authorize_url = urlparse(self.__api_url)
        # formatdate() always emits English day/month names and a "GMT"
        # suffix. strftime('%a, %d %b %Y %H:%M:%S %Z') followed the process
        # locale and rendered the zone as "UTC", which is not RFC 1123.
        date = formatdate(usegmt=True)
        signature_origin = "host: {}\ndate: {}\nGET {} HTTP/1.1".format(
            authorize_url.netloc, date, authorize_url.path
        )
        signature = base64.b64encode(
            hmac.new(
                self.__api_secret.encode(),
                signature_origin.encode(),
                digestmod='sha256',
            ).digest()
        ).decode()
        authorization_origin = (
            'api_key="{}",algorithm="{}",headers="{}",signature="{}"'.format(
                self.__api_key, "hmac-sha256", "host date request-line", signature
            )
        )
        authorization = base64.b64encode(authorization_origin.encode()).decode()
        params = {
            "authorization": authorization,
            "date": date,
            "host": authorize_url.netloc,
        }
        return self.__api_url + "?" + urlencode(params)

    def __build_inputs(
        self,
        message: dict,
        user_id: str = "001",
        domain: str = "general",
        temperature: float = 0.5,
        max_tokens: int = 4096,
    ):
        """Serialize one chat request to the JSON string sent over the socket.

        Args:
            message: Value placed under ``payload.message`` (in this class it
                is whatever ``get_prompt`` returns).
            user_id: Sent as ``header.uid``.
            domain: Sent as ``parameter.chat.domain``.
            temperature: Sent as ``parameter.chat.temperature``.
            max_tokens: Sent as ``parameter.chat.max_tokens``.

        Returns:
            The request encoded as a JSON string.
        """
        input_dict = {
            "header": {
                "app_id": self.__app_id,
                "uid": user_id,
            },
            "parameter": {
                "chat": {
                    "domain": domain,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                }
            },
            "payload": {
                "message": message
            },
        }
        return json.dumps(input_dict)

    def chat(
        self,
        query: str,
        history: list = None,  # conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.5,
    ):
        """Send ``query`` and return the reply once the answer is complete.

        Consumes ``streaming_output`` so both entry points share one request
        path (and its connection handling).

        Args:
            query: The user's question or prompt.
            history: Conversation history; a new list is used when omitted.
            user_id: Sent as ``header.uid``.
            domain: Model domain sent under ``parameter.chat``.
            max_tokens: Requested token limit, capped at the instance ceiling.
            temperature: Sampling temperature.

        Returns:
            The ``response`` of the last frame received. If the connection
            closes mid-answer, the response received so far is returned
            instead of ``None``.
        """
        response = ""
        for response, _ in self.streaming_output(
            query,
            history,
            user_id=user_id,
            domain=domain,
            max_tokens=max_tokens,
            temperature=temperature,
        ):
            pass
        return response

    def streaming_output(
        self,
        query: str,
        history: list = None,  # conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.5,
    ):
        """Send ``query`` and yield ``(response, history)`` for every reply frame.

        Used for terminal chat. Iteration stops after a frame whose status is
        2 (the final frame of a complete answer) or whose response is empty.
        If the server closes the socket after the first frame, "Connection
        closed" is printed and iteration ends. The socket is always closed,
        including when sending or the first receive fails, or when the
        consumer stops iterating early.

        Args:
            query: The user's question or prompt.
            history: Conversation history passed to ``get_prompt`` and
                ``process_response``; a new list is used when omitted.
            user_id: Sent as ``header.uid``.
            domain: Model domain sent under ``parameter.chat``.
            max_tokens: Requested token limit, capped at the instance ceiling
                (4096 unless changed with ``__set_max_tokens``).
            temperature: Sampling temperature.

        Yields:
            ``(response, history)`` as returned by ``process_response`` for
            each frame (presumably ``response`` is the text accumulated so
            far — depends on ``process_response``).
        """
        if history is None:
            history = []
        max_tokens = min(max_tokens, self.__max_token)
        message = get_prompt(query, history)
        input_str = self.__build_inputs(
            message=message,
            user_id=user_id,
            domain=domain,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        ws = create_connection(self.__get_authorization_url())
        try:
            # Errors while sending or reading the first frame still propagate
            # to the caller, but the socket is closed by the outer finally.
            ws.send(input_str)
            response_str = ws.recv()
            try:
                while True:
                    response, history, status = process_response(
                        response_str, history)
                    yield response, history
                    # status 2 is the final result of a complete conversation;
                    # doc url: https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
                    if len(response) == 0 or status == 2:
                        break
                    response_str = ws.recv()
            except WebSocketConnectionClosedException:
                print("Connection closed")
        finally:
            ws.close()

    def chat_stream(self):
        """Run an interactive terminal session.

        Commands: ``init`` loads a script file via ``init_script`` and asks
        the model to host that murder-mystery script; ``create`` collects
        script details and saves them via ``create_script``; ``exit`` or
        ``stop`` (or Ctrl-C / Ctrl-D) ends the session. Any other input is
        sent to the model and the reply is streamed to the terminal. The
        history returned with each reply is carried into the next turn.
        """
        history = []
        try:
            print("输入init来初始化剧本,输入create来创作剧本,输入exit或stop来终止对话\n")
            while True:
                query = input("Ask: ")
                if query == 'init':
                    jsonfile = input("请输入剧本文件路径:")
                    try:
                        script_data = init_script(history, jsonfile)
                        print(
                            f"正在导入剧本{script_data['name']},角色信息:{script_data['characters']},剧情介绍:{script_data['summary']}")
                    except (OSError, ValueError, KeyError) as err:
                        # A mistyped path or malformed script file should not
                        # end the whole session.
                        print(f"剧本导入失败: {err}")
                        continue
                    query = f"我希望你能够扮演这个剧本杀游戏的主持人,我希望你能够逐步引导玩家到达最终结局,同时希望你在游戏中设定一些随机事件,需要玩家依靠自身的能力解决,当玩家做出偏离主线的行为或者与剧本无关的行为时,你需要委婉地将玩家引导至正常游玩路线中,对于玩家需要决策的事件,你需要提供一些行动推荐,下面是剧本介绍:{script_data}"
                elif query == 'create':
                    name = input('请输入剧本名称:')
                    characters = input('请输入角色信息:')
                    summary = input('请输入剧情介绍:')
                    details = input('请输入剧本细节')
                    create_script(name, characters, summary, details)
                    print('剧本创建成功!')
                    continue
                elif query == "exit" or query == "stop":
                    break
                # Keep the history yielded with each frame so the next turn
                # sees this exchange even if process_response returns a new list.
                for response, history in self.streaming_output(query, history):
                    print("\r" + response, end="")
                print("\n")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D end the session like "exit" instead of a traceback.
            pass
        finally:
            print("\nThank you for using the SparkDesk AI. Welcome to use it again!")
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional
class Spark_forlangchain(LLM):
    """LangChain ``LLM`` wrapper that answers prompts through ``SparkAPI``.

    Each call builds a fresh ``SparkAPI`` client from the stored credentials
    and returns the text produced by ``SparkAPI.chat``.
    """

    # Integer field. It is only reported by _identifying_params and plays no
    # part in the API call (looks like a leftover from LangChain's custom-LLM
    # example — TODO confirm before removing).
    n: int
    # Credentials forwarded to SparkAPI on every call.
    app_id: str
    api_key: str
    api_secret: str

    @property
    def _llm_type(self) -> str:
        """Return the type name LangChain records for this LLM."""
        return "Spark"

    def _call(
        self,
        query: str,
        history: Optional[list] = None,  # conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``query`` to Spark and return the reply text.

        Args:
            query: Prompt text supplied by LangChain.
            history: Conversation history forwarded to ``SparkAPI.chat``.
            user_id: Forwarded to ``SparkAPI.chat``.
            domain: Forwarded to ``SparkAPI.chat``.
            max_tokens: Forwarded to ``SparkAPI.chat``.
            temperature: Forwarded to ``SparkAPI.chat``.
            stop: Stop sequences; not supported, so any value is rejected.
            run_manager: Callback manager that newer LangChain versions pass
                to ``_call``; accepted for signature compatibility, unused.
            **kwargs: Extra keyword arguments LangChain may forward; ignored.

        Returns:
            The reply text, or ``""`` when ``SparkAPI.chat`` returned
            ``None`` (which can happen after it prints "Connection closed").

        Raises:
            ValueError: If ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        bot = SparkAPI(app_id=self.app_id, api_key=self.api_key,
                       api_secret=self.api_secret)
        response = bot.chat(query, history, user_id,
                            domain, max_tokens, temperature)
        # Honour the declared ``-> str`` contract for LangChain callers.
        return response if response is not None else ""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (a dict holding ``n``)."""
        return {"n": self.n}