Pearx commited on
Commit
cea8bca
1 Parent(s): 00fa5e9

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +72 -65
helper.py CHANGED
@@ -1,73 +1,48 @@
1
  import json
2
  import os
3
  import re
4
- import builtins
5
- import shutil
6
  import uuid
7
- from functools import wraps
8
  import streamlit as st
9
  import pandas as pd
10
  from custom import *
11
-
12
-
13
- # 聊天记录处理
14
- def clear_folder(path):
15
- if not os.path.exists(path):
16
- return
17
- for file_name in os.listdir(path):
18
- file_path = os.path.join(path, file_name)
19
- try:
20
- shutil.rmtree(file_path)
21
- except Exception:
22
- pass
23
-
24
-
25
- def set_chats_path():
26
- save_path = 'chat_history'
27
- if 'apikey' not in st.secrets:
28
- clear_folder('tem_files')
29
- save_path = 'tem_files/tem_chat' + str(uuid.uuid4())
30
- return save_path
31
-
32
-
33
- # 重新open函数,路径不存在时自动创建
34
- def create_path(func):
35
- @wraps(func)
36
- def wrapper(path, *args, **kwargs):
37
- if not os.path.exists(os.path.dirname(path)):
38
- os.makedirs(os.path.dirname(path))
39
- return func(path, *args, **kwargs)
40
-
41
- return wrapper
42
-
43
-
44
- open = create_path(builtins.open)
45
-
46
-
47
- def get_history_chats(path):
48
- try:
49
- os.makedirs(path)
50
- except FileExistsError:
51
- pass
52
- files = [f for f in os.listdir(f'./{path}') if f.endswith('.json')]
53
- files_with_time = [(f, os.stat(f'./{path}/' + f).st_ctime) for f in files]
54
- sorted_files = sorted(files_with_time, key=lambda x: x[1], reverse=True)
55
- chat_names = [os.path.splitext(f[0])[0] for f in sorted_files]
56
- if len(chat_names) == 0:
57
- chat_names.append('New Chat_' + str(uuid.uuid4()))
58
  return chat_names
59
 
60
 
61
  def save_data(path: str, file_name: str, history: list, paras: dict, contexts: dict, **kwargs):
 
 
62
  with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
63
  json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)
64
 
65
 
66
- def remove_data(path: str, file_name: str):
67
  try:
68
- os.remove(f"./{path}/{file_name}.json")
69
  except FileNotFoundError:
70
  pass
 
 
 
 
 
 
 
71
 
72
 
73
  def load_data(path: str, file_name: str) -> dict:
@@ -76,12 +51,14 @@ def load_data(path: str, file_name: str) -> dict:
76
  data = json.load(f)
77
  return data
78
  except FileNotFoundError:
79
- with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
80
- f.write(json.dumps(initial_content_all))
81
- return initial_content_all
 
 
82
 
83
 
84
- def show_each_message(message, role, area=None):
85
  if area is None:
86
  area = [st.markdown] * 2
87
  if role == 'user':
@@ -106,23 +83,26 @@ def show_messages(messages: list):
106
 
107
 
108
  # 根据context_level提取history
109
- def get_history_input(history, level):
110
- df_history = pd.DataFrame(history)
111
- df_system = df_history.query('role=="system"')
112
- df_input = df_history.query('role!="system"')
113
- df_input = df_input[-level * 2:]
114
- res = pd.concat([df_system, df_input], ignore_index=True).to_dict('records')
 
 
 
115
  return res
116
 
117
 
118
  # 去除#号右边的空格
119
- def remove_hashtag_right__space(text):
120
  res = re.sub(r"(#+)\s*", r"\1", text)
121
  return res
122
 
123
 
124
  # 提取文本
125
- def extract_chars(text, num):
126
  char_num = 0
127
  chars = ''
128
  for char in text:
@@ -134,4 +114,31 @@ def extract_chars(text, num):
134
  chars += char
135
  if char_num >= num:
136
  break
137
- return chars
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import os
3
  import re
 
 
4
  import uuid
 
5
  import streamlit as st
6
  import pandas as pd
7
  from custom import *
8
+ import copy
9
+ import io
10
+
11
+
12
+ def get_history_chats(path: str) -> list:
13
+ if "apikey" in st.secrets:
14
+ if not os.path.exists(path):
15
+ os.makedirs(path)
16
+ files = [f for f in os.listdir(f'./{path}') if f.endswith('.json')]
17
+ files_with_time = [(f, os.stat(f'./{path}/' + f).st_ctime) for f in files]
18
+ sorted_files = sorted(files_with_time, key=lambda x: x[1], reverse=True)
19
+ chat_names = [os.path.splitext(f[0])[0] for f in sorted_files]
20
+ if len(chat_names) == 0:
21
+ chat_names.append('New Chat_' + str(uuid.uuid4()))
22
+ else:
23
+ chat_names = ['New Chat_' + str(uuid.uuid4())]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  return chat_names
25
 
26
 
27
  def save_data(path: str, file_name: str, history: list, paras: dict, contexts: dict, **kwargs):
28
+ if not os.path.exists(path):
29
+ os.makedirs(path)
30
  with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
31
  json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)
32
 
33
 
34
+ def remove_data(path: str, chat_name: str):
35
  try:
36
+ os.remove(f"./{path}/{chat_name}.json")
37
  except FileNotFoundError:
38
  pass
39
+ # 清除缓存
40
+ try:
41
+ st.session_state.pop('history' + chat_name)
42
+ for item in ["context_select", "context_input", "context_level", *initial_content_all['paras']]:
43
+ st.session_state.pop(item + chat_name + "value")
44
+ except KeyError:
45
+ pass
46
 
47
 
48
  def load_data(path: str, file_name: str) -> dict:
 
51
  data = json.load(f)
52
  return data
53
  except FileNotFoundError:
54
+ content = copy.deepcopy(initial_content_all)
55
+ if "apikey" in st.secrets:
56
+ with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
57
+ f.write(json.dumps(content))
58
+ return content
59
 
60
 
61
+ def show_each_message(message: str, role: str, area=None):
62
  if area is None:
63
  area = [st.markdown] * 2
64
  if role == 'user':
 
83
 
84
 
85
  # 根据context_level提取history
86
+ def get_history_input(history: list, level: int) -> list:
87
+ if level != 0:
88
+ df_history = pd.DataFrame(history)
89
+ df_system = df_history.query('role=="system"')
90
+ df_input = df_history.query('role!="system"')
91
+ df_input = df_input[-level * 2:]
92
+ res = pd.concat([df_system, df_input], ignore_index=True).to_dict('records')
93
+ else:
94
+ res = []
95
  return res
96
 
97
 
98
  # 去除#号右边的空格
99
+ def remove_hashtag_right__space(text: str) -> str:
100
  res = re.sub(r"(#+)\s*", r"\1", text)
101
  return res
102
 
103
 
104
  # 提取文本
105
+ def extract_chars(text: str, num: int) -> str:
106
  char_num = 0
107
  chars = ''
108
  for char in text:
 
114
  chars += char
115
  if char_num >= num:
116
  break
117
+ return chars
118
+
119
+
120
+ def download_history(history: list):
121
+ md_text = ""
122
+ for msg in history:
123
+ if msg['role'] == 'user':
124
+ md_text += f'## {user_name}:\n{msg["content"]}\n'
125
+ elif msg['role'] == 'assistant':
126
+ md_text += f'## {gpt_name}:\n{msg["content"]}\n'
127
+ output = io.BytesIO()
128
+ output.write(md_text.encode('utf-8'))
129
+ output.seek(0)
130
+ return output
131
+
132
+
133
+ def filename_correction(filename: str) -> str:
134
+ pattern = r'[^\w\.-]'
135
+ filename = re.sub(pattern, '', filename)
136
+ return filename
137
+
138
+
139
+ def url_correction(text: str) -> str:
140
+ pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+#]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
141
+ links = re.findall(pattern, text)
142
+ for link in links:
143
+ text = text.replace(link, " " + link + " ")
144
+ return text