from __future__ import annotations
import datetime
import json
import logging
import traceback
import colorama
import gradio as gr
import requests

from .. import shared
from ..config import hide_history_when_not_logged_in, retrieve_proxy, sensitive_id, usage_limit
from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel
class OpenAIClient(BaseLLMModel):
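    """OpenAI Chat Completions client built on top of BaseLLMModel.

    Sends streaming and non-streaming requests to the configured
    chat-completion endpoint, decodes the SSE responses, and reports
    monthly billing usage.
    """
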
def __init__(
self,
model_name,
api_key,
system_prompt=INITIAL_SYSTEM_PROMPT,
temperature=1.0,
top_p=1.0,
user_name=""
) -> None:
super().__init__(
model_name=MODEL_METADATA[model_name]["model_name"],
temperature=temperature,
top_p=top_p,
system_prompt=system_prompt,
user=user_name
)
self.api_key = api_key
self.need_api_key = True
self._refresh_header()
def get_answer_stream_iter(self):
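        """Stream the answer, yielding the accumulated partial text.

        Each yielded value is the full answer so far, so the caller can simply
        replace the displayed text on every update.
        """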
response = self._get_response(stream=True)
if response is not None:
            stream_iter = self._decode_chat_response(response)
            partial_text = ""
            for delta in stream_iter:
                partial_text += delta
                yield partial_text
else:
yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG
def get_answer_at_once(self):
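        """Send a single non-streaming request and return (content, total_token_count)."""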
        response = self._get_response()
        if response is None:
            # Mirror the streaming path's error handling when the request failed.
            return STANDARD_ERROR_MSG + GENERAL_ERROR_MSG, 0
        response = json.loads(response.text)
        content = response["choices"][0]["message"]["content"]
        total_token_count = response["usage"]["total_tokens"]
        return content, total_token_count
def count_token(self, user_input):
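        """Count tokens for the pending user input.

        On the first turn (no recorded token counts yet) the system prompt's
        tokens are included as well.
        """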
input_token_count = count_token(construct_user(user_input))
if self.system_prompt is not None and len(self.all_token_counts) == 0:
system_prompt_token_count = count_token(
construct_system(self.system_prompt)
)
return input_token_count + system_prompt_token_count
return input_token_count
def billing_info(self):
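        """Query this month's API usage and render it with the billing_info.html template.

        Returns an error message string when the usage API is unreachable or the
        sensitive_id is missing or invalid.
        """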
try:
curr_time = datetime.datetime.now()
last_day_of_month = get_last_day_of_month(
curr_time).strftime("%Y-%m-%d")
first_day_of_month = curr_time.replace(day=1).strftime("%Y-%m-%d")
usage_url = f"{shared.state.usage_api_url}?start_date={first_day_of_month}&end_date={last_day_of_month}"
            try:
                usage_data = self._get_billing_data(usage_url)
            except Exception as e:
                # logging.error("Failed to fetch API usage: " + str(e))
                if "Invalid authorization header" in str(e):
                    # "Failed to get API usage: sensitive_id must be set correctly in config.json"
                    return i18n("**获取API使用情况失败**,需在填写`config.json`中正确填写sensitive_id")
                elif "Incorrect API key provided: sess" in str(e):
                    # "Failed to get API usage: sensitive_id is wrong or has expired"
                    return i18n("**获取API使用情况失败**,sensitive_id错误或已过期")
                return i18n("**获取API使用情况失败**")  # "Failed to get API usage"
            rounded_usage = round(usage_data["total_usage"] / 100, 5)
            usage_percent = round(usage_data["total_usage"] / usage_limit, 2)
            from ..webui import get_html

            return get_html("billing_info.html").format(
                label=i18n("本月使用金额"),  # "Usage this month"
                usage_percent=usage_percent,
                rounded_usage=rounded_usage,
                usage_limit=usage_limit,
            )
except requests.exceptions.ConnectTimeout:
status_text = (
STANDARD_ERROR_MSG + CONNECTION_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
)
return status_text
except requests.exceptions.ReadTimeout:
status_text = STANDARD_ERROR_MSG + READ_TIMEOUT_MSG + ERROR_RETRIEVE_MSG
return status_text
        except Exception as e:
            traceback.print_exc()
            logging.error(i18n("获取API使用情况失败:") + str(e))  # "Failed to get API usage:"
            return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG
def set_token_upper_limit(self, new_upper_limit):
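        # No-op: presumably the token limit is fixed by the selected OpenAI model
        # and is not meant to be overridden from the UI.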
pass
    @shared.state.switching_api_key  # this decorator has no effect unless multi-account mode is enabled
def _get_response(self, stream=False):
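        """POST the current conversation to the chat-completion endpoint.

        Returns the requests.Response (streaming or not), or None if the
        request raised an exception.
        """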
openai_api_key = self.api_key
system_prompt = self.system_prompt
history = self.history
logging.debug(colorama.Fore.YELLOW +
f"{history}" + colorama.Fore.RESET)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {openai_api_key}",
}
if system_prompt is not None:
history = [construct_system(system_prompt), *history]
payload = {
"model": self.model_name,
"messages": history,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n_choices,
"stream": stream,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
}
if self.max_generation_token is not None:
payload["max_tokens"] = self.max_generation_token
if self.stop_sequence is not None:
payload["stop"] = self.stop_sequence
if self.logit_bias is not None:
payload["logit_bias"] = self.logit_bias
if self.user_identifier:
payload["user"] = self.user_identifier
if stream:
timeout = TIMEOUT_STREAMING
else:
timeout = TIMEOUT_ALL
        # If a custom API host is configured, send the request there; otherwise use the default endpoint.
        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
            logging.debug(f"Using custom API URL: {shared.state.chat_completion_url}")
with retrieve_proxy():
try:
response = requests.post(
shared.state.chat_completion_url,
headers=headers,
json=payload,
stream=stream,
timeout=timeout,
)
            except Exception:
                traceback.print_exc()
                return None
return response
def _refresh_header(self):
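        """Rebuild the billing-request headers from the configured sensitive_id."""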
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {sensitive_id}",
}
def _get_billing_data(self, billing_url):
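        """GET usage data from the billing endpoint and return the parsed JSON."""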
with retrieve_proxy():
response = requests.get(
billing_url,
headers=self.headers,
timeout=TIMEOUT_ALL,
)
if response.status_code == 200:
data = response.json()
return data
else:
raise Exception(
f"API request failed with status code {response.status_code}: {response.text}"
)
def _decode_chat_response(self, response):
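        """Parse an SSE chat-completion stream and yield content deltas.

        Each line is expected in the form "data: {json}"; iteration stops at the
        first chunk whose finish_reason is "stop".
        """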
        error_msg = ""
        for chunk in response.iter_lines():
            if chunk:
                chunk = chunk.decode()
                chunk_length = len(chunk)
                try:
                    # SSE lines are prefixed with "data: "; strip the prefix before parsing.
                    chunk = json.loads(chunk[6:])
                except Exception:
                    # "JSON parse error, received:"
                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
                    error_msg += chunk
                    continue
                try:
                    if chunk_length > 6 and "delta" in chunk["choices"][0]:
                        finish_reason = chunk["choices"][0].get(
                            "finish_reason", chunk.get("finish_reason"))
                        if finish_reason == "stop":
                            break
                        try:
                            yield chunk["choices"][0]["delta"]["content"]
                        except Exception:
                            continue
                except Exception:
                    print(f"ERROR: {chunk}")
                    continue
        if error_msg and error_msg != "data: [DONE]":
            raise Exception(error_msg)
def set_key(self, new_access_key):
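        """Update the API key via the base class, then refresh the billing headers."""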
ret = super().set_key(new_access_key)
self._refresh_header()
return ret
def _single_query_at_once(self, history, temperature=1.0):
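        """Send one standalone chat request (used for auto-naming chat histories)."""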
timeout = TIMEOUT_ALL
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "model": self.model_name,
            "messages": history,
            # temperature belongs in the request body, not the headers
            "temperature": temperature,
        }
        # If a custom API host is configured, send the request there; otherwise use the default endpoint.
        if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
            logging.debug(f"Using custom API URL: {shared.state.chat_completion_url}")
with retrieve_proxy():
response = requests.post(
shared.state.chat_completion_url,
headers=headers,
json=payload,
stream=False,
timeout=timeout,
)
return response
def auto_name_chat_history(self, name_chat_method, user_question, chatbot, user_name, single_turn_checkbox):
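        """Pick a filename for the current chat history and rename it.

        Depending on name_chat_method, either ask the model to summarize the
        first exchange into a title or fall back to the first 16 characters of
        the user's question.
        """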
if len(self.history) == 2 and not single_turn_checkbox and not hide_history_when_not_logged_in:
user_question = self.history[0]["content"]
            if name_chat_method == i18n("模型自动总结(消耗tokens)"):  # "summarize with the model (consumes tokens)"
ai_answer = self.history[1]["content"]
try:
history = [
{ "role": "system", "content": SUMMARY_CHAT_SYSTEM_PROMPT},
{ "role": "user", "content": f"Please write a title based on the following conversation:\n---\nUser: {user_question}\nAssistant: {ai_answer}"}
]
response = self._single_query_at_once(history, temperature=0.0)
response = json.loads(response.text)
content = response["choices"][0]["message"]["content"]
filename = replace_special_symbols(content) + ".json"
except Exception as e:
logging.info(f"自动命名失败。{e}")
filename = replace_special_symbols(user_question)[:16] + ".json"
return self.rename_chat_history(filename, chatbot, user_name)
            elif name_chat_method == i18n("第一条提问"):  # "first question"
filename = replace_special_symbols(user_question)[:16] + ".json"
return self.rename_chat_history(filename, chatbot, user_name)
else:
return gr.update()
else:
return gr.update()