Pearx commited on
Commit
16d9c95
·
1 Parent(s): e9a3462

Update helper.py

Browse files
Files changed (1) hide show
  1. helper.py +57 -19
helper.py CHANGED
@@ -1,8 +1,11 @@
1
  import json
2
- import streamlit as st
3
  import os
4
- from set_context import set_context
 
5
  import uuid
 
 
 
6
 
7
  set_context_all = {"不设置": ""}
8
  set_context_all.update(set_context)
@@ -19,10 +22,8 @@ gpt_svg = """
19
  """
20
  # 内容背景
21
  user_background_color = '#ffffff'
22
- gpt_background_color = '#f9fafb'
23
- # 聊天记录文件夹名称
24
- history_chats_filename = 'chat_history'
25
-
26
  initial_content_history = [{"role": 'system',
27
  "content": '当你的回复中涉及代码块时,请在markdown语法中标明语言类型。如果不涉及,请忽略这句话。'}]
28
  initial_content_all = {"history": initial_content_history,
@@ -35,18 +36,52 @@ initial_content_all = {"history": initial_content_history,
35
  "contexts": {
36
  'context_select': '不设置',
37
  'context_input': '',
38
- 'context_level': 5
39
  }
40
  }
41
 
42
 
43
- def get_history_chats(file_name=history_chats_filename):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  try:
45
- os.mkdir(file_name)
46
  except FileExistsError:
47
  pass
48
- files = [f for f in os.listdir(f'./{file_name}') if f.endswith('.json')]
49
- files_with_time = [(f, os.stat(f'./{file_name}/' + f).st_ctime) for f in files]
50
  sorted_files = sorted(files_with_time, key=lambda x: x[1], reverse=True)
51
  chat_names = [os.path.splitext(f[0])[0] for f in sorted_files]
52
  if len(chat_names) == 0:
@@ -54,22 +89,25 @@ def get_history_chats(file_name=history_chats_filename):
54
  return chat_names
55
 
56
 
57
- def save_data(current_chat: str, history: list, paras: dict, contexts: dict, **kwargs):
58
- with open(f"./{history_chats_filename}/{current_chat}.json", 'w', encoding='utf-8') as f:
59
  json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)
60
 
61
 
62
- def remove_data(current_chat: str):
63
- os.remove(f"./{history_chats_filename}/{current_chat}.json")
 
 
 
64
 
65
 
66
- def load_data(current_chat: str) -> dict:
67
  try:
68
- with open(f"./{history_chats_filename}/{current_chat}.json", 'r', encoding='utf-8') as f:
69
  data = json.load(f)
70
  return data
71
  except FileNotFoundError:
72
- with open(f"./{history_chats_filename}/{current_chat}.json", 'w', encoding='utf-8') as f:
73
  f.write(json.dumps(initial_content_all))
74
  return initial_content_all
75
 
@@ -95,4 +133,4 @@ def show_messages(messages: list):
95
  if (each["role"] == "user") or (each["role"] == "assistant"):
96
  show_each_message(each["content"], each["role"])
97
  if each["role"] == "assistant":
98
- st.write("---")
 
1
  import json
 
2
  import os
3
+ import builtins
4
+ import shutil
5
  import uuid
6
+ from functools import wraps
7
+ import streamlit as st
8
+ from set_context import set_context
9
 
10
  set_context_all = {"不设置": ""}
11
  set_context_all.update(set_context)
 
22
  """
23
  # 内容背景
24
  user_background_color = '#ffffff'
25
+ gpt_background_color = '#f0f2f6'
26
+ # 模型初始设置
 
 
27
  initial_content_history = [{"role": 'system',
28
  "content": '当你的回复中涉及代码块时,请在markdown语法中标明语言类型。如果不涉及,请忽略这句话。'}]
29
  initial_content_all = {"history": initial_content_history,
 
36
  "contexts": {
37
  'context_select': '不设置',
38
  'context_input': '',
39
+ 'context_level': 4
40
  }
41
  }
42
 
43
 
44
+ # 聊天记录处理
45
+ def clear_folder(path):
46
+ if not os.path.exists(path):
47
+ return
48
+ for file_name in os.listdir(path):
49
+ file_path = os.path.join(path, file_name)
50
+ try:
51
+ shutil.rmtree(file_path)
52
+ except Exception:
53
+ pass
54
+
55
+
56
+ def set_chats_path():
57
+ save_path = 'chat_history'
58
+ if 'apikey' not in st.secrets:
59
+ clear_folder('tem_files')
60
+ save_path = 'tem_files/tem_chat' + str(uuid.uuid4())
61
+ return save_path
62
+
63
+
64
+ # 重新open函数,路径不存在时自动创建
65
+ def create_path(func):
66
+ @wraps(func)
67
+ def wrapper(path, *args, **kwargs):
68
+ if not os.path.exists(os.path.dirname(path)):
69
+ os.makedirs(os.path.dirname(path))
70
+ return func(path, *args, **kwargs)
71
+
72
+ return wrapper
73
+
74
+
75
+ open = create_path(builtins.open)
76
+
77
+
78
+ def get_history_chats(path):
79
  try:
80
+ os.makedirs(path)
81
  except FileExistsError:
82
  pass
83
+ files = [f for f in os.listdir(f'./{path}') if f.endswith('.json')]
84
+ files_with_time = [(f, os.stat(f'./{path}/' + f).st_ctime) for f in files]
85
  sorted_files = sorted(files_with_time, key=lambda x: x[1], reverse=True)
86
  chat_names = [os.path.splitext(f[0])[0] for f in sorted_files]
87
  if len(chat_names) == 0:
 
89
  return chat_names
90
 
91
 
92
+ def save_data(path: str, file_name: str, history: list, paras: dict, contexts: dict, **kwargs):
93
+ with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
94
  json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)
95
 
96
 
97
+ def remove_data(path: str, file_name: str):
98
+ try:
99
+ os.remove(f"./{path}/{file_name}.json")
100
+ except FileNotFoundError:
101
+ pass
102
 
103
 
104
+ def load_data(path: str, file_name: str) -> dict:
105
  try:
106
+ with open(f"./{path}/{file_name}.json", 'r', encoding='utf-8') as f:
107
  data = json.load(f)
108
  return data
109
  except FileNotFoundError:
110
+ with open(f"./{path}/{file_name}.json", 'w', encoding='utf-8') as f:
111
  f.write(json.dumps(initial_content_all))
112
  return initial_content_all
113
 
 
133
  if (each["role"] == "user") or (each["role"] == "assistant"):
134
  show_each_message(each["content"], each["role"])
135
  if each["role"] == "assistant":
136
+ st.write("---")