import base64
import hmac
import json
from datetime import datetime, timezone
from email.utils import format_datetime
from typing import Any, List, Mapping, Optional
from urllib.parse import urlencode, urlparse

from langchain.llms.base import LLM
from websocket import create_connection, WebSocketConnectionClosedException

# NOTE(review): the module name in this import was lost from the recovered
# source (it was stripped down to "from  import ..."); ``utils.tools`` is a
# best guess -- confirm it against the project layout.
from utils.tools import get_prompt, process_response, init_script, create_script


class SparkAPI:
    """Minimal client for the iFlytek Spark (SparkDesk) chat WebSocket API."""

    # NOTE(review): the recovered source only kept 'wss://'. v1.1/chat is the
    # endpoint that pairs with domain="general" used by the methods below --
    # confirm against the API version the app is provisioned for.
    __api_url = 'wss://spark-api.xf-yun.com/v1.1/chat'
    # Upper bound applied to every request's max_tokens (the API caps it at 4096).
    __max_token = 4096

    def __init__(self, app_id, api_key, api_secret):
        self.__app_id = app_id
        self.__api_key = api_key
        self.__api_secret = api_secret

    def __set_max_tokens(self, token):
        """Set the upper bound applied to ``max_tokens`` in later requests.

        Invalid values (non-int, bool, or not greater than 0) are reported on
        stdout and ignored, leaving the current bound unchanged.
        """
        if not isinstance(token, int) or isinstance(token, bool) or token <= 0:
            print("set_max_tokens() error: tokens should be a positive integer!")
            return
        self.__max_token = token

    def __get_authorization_url(self):
        """Return the WebSocket URL carrying the HMAC-SHA256 auth query params.

        Signing scheme (APIKey / APISecret come from the iFlytek console):
        1) signature_origin = "host: <host>\\ndate: <date>\\nGET <path> HTTP/1.1"
        2) signature = base64(HMAC-SHA256(api_secret, signature_origin))
        3) authorization = base64 of the api_key/algorithm/headers/signature line
        The authorization, date and host values are URL-encoded onto the URL.
        """
        authorize_url = urlparse(self.__api_url)

        # RFC 1123 date in GMT, e.g. "Mon, 20 Jan 2025 10:00:00 GMT".
        # format_datetime is locale-independent and emits "GMT"; strftime('%Z')
        # on a timezone.utc datetime would yield "UTC" instead.
        date = format_datetime(datetime.now(timezone.utc), usegmt=True)

        signature_origin = "host: {}\ndate: {}\nGET {} HTTP/1.1".format(
            authorize_url.netloc, date, authorize_url.path
        )
        signature = base64.b64encode(
            hmac.new(
                self.__api_secret.encode(),
                signature_origin.encode(),
                digestmod='sha256',
            ).digest()
        ).decode()

        authorization_origin = (
            'api_key="{}",algorithm="{}",headers="{}",signature="{}"'.format(
                self.__api_key, "hmac-sha256", "host date request-line", signature
            )
        )
        authorization = base64.b64encode(authorization_origin.encode()).decode()

        params = {
            "authorization": authorization,
            "date": date,
            "host": authorize_url.netloc,
        }
        return self.__api_url + "?" + urlencode(params)

    def __build_inputs(
        self,
        message: dict,
        user_id: str = "001",
        domain: str = "general",
        temperature: float = 0.5,
        max_tokens: int = 4096,
    ):
        """Serialize one chat request (header/parameter/payload) to a JSON string.

        ``message`` is placed verbatim under ``payload.message``; it is whatever
        ``get_prompt`` produced.
        """
        input_dict = {
            "header": {
                "app_id": self.__app_id,
                "uid": user_id,
            },
            "parameter": {
                "chat": {
                    "domain": domain,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                }
            },
            "payload": {
                "message": message
            },
        }
        return json.dumps(input_dict)

    def chat(
        self,
        query: str,
        history: Optional[list] = None,  # store the conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.5,
    ):
        """Send ``query`` and return the model's answer.

        Drains :meth:`streaming_output` and returns the last ``response`` value
        produced by ``process_response``. If the connection closes early, the
        last response received so far is returned (``""`` if none arrived), so
        callers always get a string.
        """
        response = ""
        for response, _ in self.streaming_output(
            query,
            history,
            user_id=user_id,
            domain=domain,
            max_tokens=max_tokens,
            temperature=temperature,
        ):
            pass
        return response

    # Stream output statement, used for terminal chat.
    def streaming_output(
        self,
        query: str,
        history: Optional[list] = None,  # store the conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.5,
    ):
        """Send ``query`` and yield ``(response, history)`` for each received frame.

        Stops after a frame whose response is empty or whose status is 2
        (presumably the Spark "final frame" marker surfaced by
        ``process_response`` -- confirm). A closed connection is reported on
        stdout and ends the stream. The socket is always closed.
        """
        if history is None:
            history = []
        # Never request more than the configured (API-capped) maximum.
        max_tokens = min(max_tokens, self.__max_token)

        message = get_prompt(query, history)
        input_str = self.__build_inputs(
            message=message,
            user_id=user_id,
            domain=domain,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        ws = create_connection(self.__get_authorization_url())
        # send/recv live inside the try so the socket is closed even if they fail.
        try:
            ws.send(input_str)
            while True:
                response, history, status = process_response(ws.recv(), history)
                yield response, history
                if not response or status == 2:
                    break
        except WebSocketConnectionClosedException:
            print("Connection closed")
        finally:
            ws.close()

    def chat_stream(self):
        """Run an interactive terminal session.

        Commands: ``init`` loads a script file via ``init_script`` and turns it
        into a game-host prompt; ``create`` saves a new script via
        ``create_script``; ``exit``/``stop`` (or Ctrl-C / EOF) quits. Any other
        non-blank input is sent to the model and the answer streamed.
        """
        history = []
        try:
            print("输入init来初始化剧本,输入create来创作剧本,输入exit或stop来终止对话\n")
            while True:
                query = input("Ask: ")
                if not query.strip():
                    # Don't send empty prompts to the API.
                    continue
                if query == 'init':
                    jsonfile = input("请输入剧本文件路径:")
                    script_data = init_script(history, jsonfile)
                    print(
                        f"正在导入剧本{script_data['name']},角色信息:{script_data['characters']},剧情介绍:{script_data['summary']}")
                    # Replace the command with the game-master instruction prompt.
                    query = f"我希望你能够扮演这个剧本杀游戏的主持人,我希望你能够逐步引导玩家到达最终结局,同时希望你在游戏中设定一些随机事件,需要玩家依靠自身的能力解决,当玩家做出偏离主线的行为或者与剧本无关的行为时,你需要委婉地将玩家引导至正常游玩路线中,对于玩家需要决策的事件,你需要提供一些行动推荐,下面是剧本介绍:{script_data}"
                if query == 'create':
                    name = input('请输入剧本名称:')
                    characters = input('请输入角色信息:')
                    summary = input('请输入剧情介绍:')
                    details = input('请输入剧本细节')
                    create_script(name, characters, summary, details)
                    print('剧本创建成功!')
                    continue
                if query == "exit" or query == "stop":
                    break
                for response, _ in self.streaming_output(query, history):
                    # "\r" redraws the line with each new response frame.
                    print("\r" + response, end="")
                print("\n")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D end the session like "exit" instead of a traceback.
            pass
        finally:
            print("\nThank you for using the SparkDesk AI. Welcome to use it again!")


class Spark_forlangchain(LLM):
    """LangChain ``LLM`` adapter that forwards prompts to :class:`SparkAPI`."""

    # Value reported by _identifying_params.
    n: int
    # iFlytek credentials passed to SparkAPI on every call.
    app_id: str
    api_key: str
    api_secret: str

    @property
    def _llm_type(self) -> str:
        """Return the type tag LangChain uses for this LLM."""
        return "Spark"

    def _call(
        self,
        query: str,
        history: Optional[list] = None,  # store the conversation history
        user_id: str = "001",
        domain: str = "general",
        max_tokens: int = 4096,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None,
    ) -> str:
        """Answer ``query`` with a fresh SparkAPI client and return the text.

        Raises:
            ValueError: if ``stop`` sequences are given (not supported).
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        bot = SparkAPI(app_id=self.app_id, api_key=self.api_key,
                       api_secret=self.api_secret)
        return bot.chat(query, history, user_id, domain, max_tokens, temperature)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"n": self.n}