AllenYkl committed on
Commit
119d590
·
1 Parent(s): 0057566

Rename bin_public/utils/tools.py to bin_public/utils/utils.py

Browse files
bin_public/utils/{tools.py → utils.py} RENAMED
@@ -5,31 +5,38 @@ import logging
5
  import json
6
  import os
7
  import datetime
8
- import hashlib
9
  import csv
 
 
 
 
 
10
 
11
- import gradio as gr
12
  from pypinyin import lazy_pinyin
13
  import tiktoken
 
 
 
 
 
14
 
15
  from bin_public.config.presets import *
 
16
 
17
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
 
 
 
18
 
19
  if TYPE_CHECKING:
20
  from typing import TypedDict
21
 
 
22
  class DataframeData(TypedDict):
23
  headers: List[str]
24
  data: List[List[str | int | bool]]
25
 
26
 
27
- initial_prompt = "You are a helpful assistant."
28
- API_URL = "https://api.openai.com/v1/chat/completions"
29
- HISTORY_DIR = "history"
30
- TEMPLATES_DIR = "templates"
31
-
32
-
33
  def count_token(message):
34
  encoding = tiktoken.get_encoding("cl100k_base")
35
  input_str = f"role: {message['role']}, content: {message['content']}"
@@ -37,36 +44,99 @@ def count_token(message):
37
  return length
38
 
39
 
40
- def parse_text(text):
41
- lines = text.split("\n")
42
- lines = [line for line in lines if line != ""]
43
- count = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  for i, line in enumerate(lines):
45
- if "```" in line:
46
- count += 1
47
- items = line.split('`')
48
- if count % 2 == 1:
49
- lines[i] = f'<pre><code class="language-{items[-1]}">'
50
- else:
51
- lines[i] = f'<br></code></pre>'
 
 
 
 
52
  else:
53
- if i > 0:
54
- if count % 2 == 1:
55
- line = line.replace("`", "\`")
56
- line = line.replace("<", "&lt;")
57
- line = line.replace(">", "&gt;")
58
- line = line.replace(" ", "&nbsp;")
59
- line = line.replace("*", "&ast;")
60
- line = line.replace("_", "&lowbar;")
61
- line = line.replace("-", "&#45;")
62
- line = line.replace(".", "&#46;")
63
- line = line.replace("!", "&#33;")
64
- line = line.replace("(", "&#40;")
65
- line = line.replace(")", "&#41;")
66
- line = line.replace("$", "&#36;")
67
- lines[i] = "<br>" + line
68
- text = "".join(lines)
69
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  def construct_text(role, text):
@@ -89,6 +159,17 @@ def construct_token_message(token, stream=False):
89
  return f"Token 计数: {token}"
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
92
  def delete_last_conversation(chatbot, history, previous_token_count):
93
  if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
94
  logging.info("由于包含报错信息,只删除chatbot记录")
@@ -210,7 +291,7 @@ def load_template(filename, mode=0):
210
  lines = [[i["act"], i["prompt"]] for i in lines]
211
  else:
212
  with open(
213
- os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
214
  ) as csvfile:
215
  reader = csv.reader(csvfile)
216
  lines = list(reader)
@@ -245,20 +326,19 @@ def reset_state():
245
 
246
 
247
  def reset_textbox():
 
248
  return gr.update(value="")
249
 
250
 
251
  def reset_default():
252
- global API_URL
253
- API_URL = "https://api.openai.com/v1/chat/completions"
254
  os.environ.pop("HTTPS_PROXY", None)
255
  os.environ.pop("https_proxy", None)
256
- return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
257
 
258
 
259
  def change_api_url(url):
260
- global API_URL
261
- API_URL = url
262
  msg = f"API地址更改为了{url}"
263
  logging.info(msg)
264
  return msg
@@ -288,12 +368,138 @@ def submit_key(key):
288
  return key, msg
289
 
290
 
291
- def sha1sum(filename):
292
- sha1 = hashlib.sha1()
293
- sha1.update(filename.encode("utf-8"))
294
- return sha1.hexdigest()
295
-
296
-
297
  def replace_today(prompt):
298
  today = datetime.datetime.today().strftime("%Y-%m-%d")
299
- return prompt.replace("{current_date}", today)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import json
6
  import os
7
  import datetime
 
8
  import csv
9
+ import requests
10
+ import re
11
+ import html
12
+ import sys
13
+ import subprocess
14
 
 
15
  from pypinyin import lazy_pinyin
16
  import tiktoken
17
+ import mdtex2html
18
+ from markdown import markdown
19
+ from pygments import highlight
20
+ from pygments.lexers import get_lexer_by_name
21
+ from pygments.formatters import HtmlFormatter
22
 
23
  from bin_public.config.presets import *
24
+ import bin_public.utils.shared as shared
25
 
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
29
+ )
30
 
31
# Typing-only declarations: this branch is guarded by TYPE_CHECKING, so the
# class below exists for static analysis only.
# NOTE(review): TYPE_CHECKING and List are not imported in the visible part of
# the file; presumably they come from `typing` further up — confirm.
if TYPE_CHECKING:
    from typing import TypedDict

    class DataframeData(TypedDict):
        # Column header labels.
        headers: List[str]
        # Row-major cell values; each cell is a str, int or bool.
        data: List[List[str | int | bool]]
38
 
39
 
 
 
 
 
 
 
40
  def count_token(message):
41
  encoding = tiktoken.get_encoding("cl100k_base")
42
  input_str = f"role: {message['role']}, content: {message['content']}"
 
44
  return length
45
 
46
 
47
def markdown_to_html_with_syntax_highlight(md_str):
    """Render Markdown to HTML, highlighting fenced code blocks with Pygments.

    Every fenced block of the form ```lang\\n...\\n``` is replaced by
    ``<pre><code class="lang">`` wrapping the Pygments-highlighted code, and
    the resulting text is then passed through ``markdown``.

    Args:
        md_str: Markdown source text.

    Returns:
        The HTML string produced by ``markdown``.
    """
    fence_pattern = r"```(\w+)?\n([\s\S]+?)\n```"

    def _highlight_block(found):
        # A fence without a language tag is treated as plain text.
        lang = found.group(1) or "text"
        try:
            lexer = get_lexer_by_name(lang, stripall=True)
        except ValueError:
            # Unknown language name: fall back to the plain-text lexer.
            lexer = get_lexer_by_name("text", stripall=True)
        body = highlight(found.group(2), lexer, HtmlFormatter())
        return f'<pre><code class="{lang}">{body}</code></pre>'

    with_highlighted_code = re.sub(
        fence_pattern, _highlight_block, md_str, flags=re.MULTILINE
    )
    return markdown(with_highlighted_code)
67
+
68
+
69
def normalize_markdown(md_text: str) -> str:
    """Tidy blank lines around Markdown list items.

    A blank line is inserted before a list that directly follows a non-blank
    line. Blank lines seen after a list has started are dropped when the next
    line is another list item or when they are the last line; otherwise they
    are kept.

    Args:
        md_text: Markdown source text.

    Returns:
        The text with list spacing normalized, joined with ``"\\n"``.
    """
    list_item = re.compile(r"^(\d+\.|-|\*|\+)\s")
    source_lines = md_text.split("\n")
    last_index = len(source_lines) - 1
    output = []
    in_list = False

    for idx, current in enumerate(source_lines):
        stripped = current.strip()

        if list_item.match(stripped):
            # First item of a list glued to preceding text: add a separator.
            if not in_list and idx > 0 and source_lines[idx - 1].strip() != "":
                output.append("")
            in_list = True
            output.append(current)
            continue

        if in_list and stripped == "":
            # Keep the blank line only when something other than a list item
            # follows it; the list state is deliberately left unchanged.
            if idx < last_index and not list_item.match(
                source_lines[idx + 1].strip()
            ):
                output.append(current)
            continue

        in_list = False
        output.append(current)

    return "\n".join(output)
91
+
92
+
93
def convert_mdtext(md_text):
    """Convert chat Markdown to HTML, handling fenced code separately.

    The text is split on triple-backtick fences. Non-code segments are passed
    through ``normalize_markdown`` and then rendered with ``markdown`` when
    they contain inline code, or with ``mdtex2html.convert`` otherwise. Code
    segments are re-fenced and rendered by
    ``markdown_to_html_with_syntax_highlight``.

    Args:
        md_text: Markdown source text.

    Returns:
        The concatenated HTML followed by ``ALREADY_CONVERTED_MARK``.
    """
    fence_re = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
    inline_code_re = re.compile(r"`(.*?)`", re.DOTALL)

    fenced_bodies = fence_re.findall(md_text)
    # The pattern has one capture group, so split() alternates
    # text, code, text, ...; every second element is a non-code segment.
    plain_segments = fence_re.split(md_text)[::2]

    pieces = []
    # Pad with "" so the trailing non-code segment still gets a partner.
    for plain, fenced in zip(plain_segments, fenced_bodies + [""]):
        if plain.strip():
            plain = normalize_markdown(plain)
            if inline_code_re.search(plain):
                rendered = markdown(plain, extensions=["tables"])
            else:
                rendered = mdtex2html.convert(plain, extensions=["tables"])
            pieces.append(rendered)
        if fenced.strip():
            # Language detection via detect_language() and blank-line removal
            # are intentionally not applied here: the original author disabled
            # both because they caused problems with large code blocks.
            refenced = f"\n```{fenced}\n\n```"
            pieces.append(markdown_to_html_with_syntax_highlight(refenced))

    return "".join(pieces) + ALREADY_CONVERTED_MARK
116
+
117
+
118
def convert_asis(userinput):
    """Wrap raw user text in a whitespace-preserving paragraph.

    Args:
        userinput: Text to show verbatim; it is HTML-escaped first.

    Returns:
        A ``<p style="white-space:pre-wrap;">`` element containing the escaped
        text, followed by ``ALREADY_CONVERTED_MARK``.
    """
    escaped = html.escape(userinput)
    paragraph = f'<p style="white-space:pre-wrap;">{escaped}</p>'
    return paragraph + ALREADY_CONVERTED_MARK
123
+
124
+
125
def detect_converted_mark(userinput):
    """Report whether *userinput* ends with ``ALREADY_CONVERTED_MARK``.

    Args:
        userinput: String to inspect.

    Returns:
        ``True`` when the string ends with the marker, ``False`` otherwise.
    """
    return userinput.endswith(ALREADY_CONVERTED_MARK)
130
+
131
+
132
def detect_language(code):
    """Split the language tag off the first line of a fenced code body.

    Args:
        code: Text found between code fences; its first line may name the
            language (for example ``"python\\nprint(1)"``).

    Returns:
        A ``(language, code_without_language)`` tuple. ``language`` is the
        lower-cased, whitespace-stripped first line. When the text starts
        with a newline or its first line is blank, ``language`` is ``""`` and
        the code is returned unchanged.
    """
    if code.startswith("\n"):
        return "", code

    # Take the first line from the text as given. Measuring a line taken from
    # a stripped copy and then slicing the unstripped text would cut at the
    # wrong offset whenever the text has leading whitespace.
    first_line, _, remainder = code.partition("\n")
    first_line = first_line.strip()
    if not first_line:
        return "", code

    return first_line.lower(), remainder.lstrip()
140
 
141
 
142
  def construct_text(role, text):
 
159
  return f"Token 计数: {token}"
160
 
161
 
162
def delete_first_conversation(history, previous_token_count):
    """Drop the oldest exchange from the history, in place.

    When ``history`` is non-empty, its first two entries and the first entry
    of ``previous_token_count`` are removed.

    Args:
        history: Mutable list of history entries.
        previous_token_count: Mutable list of token counts.

    Returns:
        A tuple ``(history, previous_token_count, token_message)`` where
        ``token_message`` is ``construct_token_message`` applied to the sum of
        the remaining counts.
    """
    if history:
        history[:2] = []
        previous_token_count.pop(0)
    token_message = construct_token_message(sum(previous_token_count))
    return history, previous_token_count, token_message
171
+
172
+
173
  def delete_last_conversation(chatbot, history, previous_token_count):
174
  if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
175
  logging.info("由于包含报错信息,只删除chatbot记录")
 
291
  lines = [[i["act"], i["prompt"]] for i in lines]
292
  else:
293
  with open(
294
+ os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
295
  ) as csvfile:
296
  reader = csv.reader(csvfile)
297
  lines = list(reader)
 
326
 
327
 
328
def reset_textbox():
    """Return a Gradio update that sets a component's value to ``""``."""
    logging.debug("重置文本框")
    cleared = gr.update(value="")
    return cleared
331
 
332
 
333
def reset_default():
    """Reset the API URL in shared state and clear HTTPS proxy variables.

    Returns:
        A tuple of a Gradio update carrying the URL returned by
        ``shared.state.reset_api_url()``, a Gradio update with an empty value,
        and a status message.
    """
    newurl = shared.state.reset_api_url()
    # Remove both spellings; pop() with a default ignores missing keys.
    for variable in ("HTTPS_PROXY", "https_proxy"):
        os.environ.pop(variable, None)
    url_update = gr.update(value=newurl)
    proxy_update = gr.update(value="")
    return url_update, proxy_update, "API URL 和代理已重置"
338
 
339
 
340
def change_api_url(url):
    """Store *url* as the API URL in shared state and log the change.

    Args:
        url: The new API URL.

    Returns:
        The message that was logged.
    """
    shared.state.set_api_url(url)
    notice = f"API地址更改为了{url}"
    logging.info(notice)
    return notice
 
368
  return key, msg
369
 
370
 
 
 
 
 
 
 
371
def replace_today(prompt):
    """Substitute ``{current_date}`` in *prompt* with today's date.

    Args:
        prompt: Text that may contain the ``{current_date}`` placeholder.

    Returns:
        The text with each placeholder replaced by the local date formatted
        as ``YYYY-MM-DD``.
    """
    date_text = f"{datetime.datetime.today():%Y-%m-%d}"
    return prompt.replace("{current_date}", date_text)
374
+
375
+
376
def get_geoip():
    """Look up the caller's IP region via ipapi.co and describe it.

    Returns:
        A user-facing message: either the detected region (with a warning
        when the region is China) or an explanation of why the lookup failed.
        Network and JSON-decoding failures are reported in the message rather
        than raised.
    """
    try:
        # The request itself must be inside the try block: a timeout or
        # connection error should produce the fallback below, not escape.
        response = requests.get("https://ipapi.co/json/", timeout=5)
        data = response.json()
    except (requests.exceptions.RequestException, ValueError):
        # ValueError covers JSON decoding errors raised by response.json().
        data = {"error": True, "reason": "连接ipapi失败"}

    if "error" in data:
        logging.warning(f"无法获取IP地址信息。\n{data}")
        # .get() avoids a KeyError if an error payload carries no reason.
        reason = data.get("reason")
        if reason == "RateLimited":
            return (
                "获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
            )
        return f"获取IP地理位置失败。原因:{reason}。你仍然可以使用聊天功能。"

    country = data["country_name"]
    if country == "China":
        text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
    else:
        text = f"您的IP区域:{country}。"
    logging.info(text)
    return text
398
+
399
+
400
def find_n(lst, max_num):
    """Count how many trailing items of *lst* sum to less than *max_num*.

    Items are dropped from the front one at a time until the remaining sum
    falls below ``max_num``.

    Args:
        lst: Sequence of numbers.
        max_num: Exclusive upper bound for the remaining sum.

    Returns:
        ``len(lst)`` when the whole sum is already below ``max_num``;
        otherwise the number of items left after dropping; ``1`` when no
        amount of dropping gets the sum below the bound.
    """
    count = len(lst)
    remaining = sum(lst)
    if remaining < max_num:
        return count

    for dropped, value in enumerate(lst, start=1):
        remaining -= value
        if remaining < max_num:
            return count - dropped
    return 1
412
+
413
+
414
def start_outputing():
    """Return button updates for the start of output.

    Returns:
        A pair of Gradio button updates: the first with ``visible=False``,
        the second with ``visible=True``.
    """
    logging.debug("显示取消按钮,隐藏发送按钮")
    hidden = gr.Button.update(visible=False)
    shown = gr.Button.update(visible=True)
    return hidden, shown
417
+
418
+
419
def end_outputing():
    """Return button updates for the end of output.

    Returns:
        A pair of Gradio button updates: the first with ``visible=True``,
        the second with ``visible=False``.
    """
    shown = gr.Button.update(visible=True)
    hidden = gr.Button.update(visible=False)
    return shown, hidden
424
+
425
+
426
def cancel_outputing():
    """Log the cancellation and call ``shared.state.interrupt()``."""
    logging.info("中止输出……")
    state = shared.state
    state.interrupt()
429
+
430
+
431
def transfer_input(inputs):
    """Pass the user's input on while clearing the textbox and swapping buttons.

    Args:
        inputs: The submitted input, returned unchanged.

    Returns:
        A 4-tuple ``(inputs, textbox_update, first_button_update,
        second_button_update)``: the textbox update comes from
        ``reset_textbox()`` and the two button updates from
        ``start_outputing()``.
    """
    # Everything is returned in one go to reduce latency.
    # Use the helper results directly instead of discarding them and
    # rebuilding identical update objects by hand.
    textbox = reset_textbox()
    first_button, second_button = start_outputing()
    return inputs, textbox, first_button, second_button
441
+
442
+
443
def get_proxies():
    """Build a proxies mapping from the proxy environment variables.

    The upper-case variable is preferred over the lower-case one for each
    scheme.

    Returns:
        A dict with ``"http"`` and/or ``"https"`` keys for the proxies that
        are set, or ``None`` when neither is set.
    """
    env = os.environ
    # Read the proxy settings from the environment.
    http_proxy = env.get("HTTP_PROXY") or env.get("http_proxy")
    https_proxy = env.get("HTTPS_PROXY") or env.get("https_proxy")

    # Record whichever proxies are configured.
    found = {}
    if http_proxy:
        logging.info(f"使用 HTTP 代理: {http_proxy}")
        found["http"] = http_proxy
    if https_proxy:
        logging.info(f"使用 HTTPS 代理: {https_proxy}")
        found["https"] = https_proxy

    # An empty mapping is reported as None.
    return found or None
461
+
462
+
463
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    """Run *command* through the shell and return its standard output.

    Args:
        command: Shell command string (executed with ``shell=True``).
        desc: Optional text printed before the command runs.
        errdesc: Optional headline for the error message; defaults to
            ``'Error running command'``.
        custom_env: Environment mapping for the child process; the current
            ``os.environ`` is used when ``None``.
        live: When true, the child's output is not captured and ``""`` is
            returned on success.

    Returns:
        The decoded standard output (UTF-8, undecodable bytes ignored), or
        ``""`` in live mode.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
    """
    if desc is not None:
        print(desc)

    child_env = os.environ if custom_env is None else custom_env
    headline = errdesc or 'Error running command'

    if live:
        completed = subprocess.run(command, shell=True, env=child_env)
        if completed.returncode != 0:
            raise RuntimeError(f"""{headline}.
Command: {command}
Error code: {completed.returncode}""")

        return ""

    completed = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True,
        env=child_env,
    )
    if completed.returncode == 0:
        return completed.stdout.decode(encoding="utf8", errors="ignore")

    # Empty streams are shown as a placeholder in the error report.
    if len(completed.stdout) > 0:
        out_text = completed.stdout.decode(encoding="utf8", errors="ignore")
    else:
        out_text = '<empty>'
    if len(completed.stderr) > 0:
        err_text = completed.stderr.decode(encoding="utf8", errors="ignore")
    else:
        err_text = '<empty>'

    message = f"""{headline}.
Command: {command}
Error code: {completed.returncode}
stdout: {out_text}
stderr: {err_text}
"""
    raise RuntimeError(message)
485
+
486
+
487
def versions_html():
    """Build an HTML snippet with the Python, Gradio and commit versions.

    The commit hash is obtained by running ``git rev-parse HEAD`` (the git
    executable can be overridden through the ``GIT`` environment variable);
    if that fails for any reason the commit is reported as unknown.

    Returns:
        An HTML string listing the Python version, ``gr.__version__`` and a
        link to the short commit hash.
    """
    git = os.environ.get('GIT', "git")
    python_version = ".".join(map(str, sys.version_info[0:3]))

    try:
        commit_hash = run(f"{git} rev-parse HEAD").strip()
    except Exception:
        # Best effort: any failure simply means the commit is unknown.
        commit_hash = "<none>"

    if commit_hash == "<none>":
        commit_info = "unknown \U0001F615"
    else:
        short_commit = commit_hash[0:7]
        commit_info = f"<a style=\"text-decoration:none\" href=\"https://github.com/GaiZhenbiao/ChuanhuChatGPT/commit/{short_commit}\">{short_commit}</a>"

    return f"""
Python: <span title="{sys.version}">{python_version}</span>
 • 
Gradio: {gr.__version__}
 • 
Commit: {commit_info}
"""