Pearx commited on
Commit
8f23e24
1 Parent(s): aeca5e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -48
app.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
  import openai
7
  from requests.models import ChunkedEncodingError
8
  from streamlit.components import v1
9
- from custom import css_code, js_code, set_context_all
10
 
11
  st.set_page_config(page_title='ChatGPT Assistant', layout='wide', page_icon='🤖')
12
  # 自定义元素样式
@@ -14,11 +13,9 @@ st.markdown(css_code, unsafe_allow_html=True)
14
 
15
  if "initial_settings" not in st.session_state:
16
  # 历史聊天窗口
17
- st.session_state["path"] = set_chats_path()
18
  st.session_state['history_chats'] = get_history_chats(st.session_state["path"])
19
  # ss参数初始化
20
- st.session_state['pre_chat'] = None
21
- st.session_state['if_chat_change'] = False
22
  st.session_state['error_info'] = ''
23
  st.session_state["current_chat_index"] = 0
24
  st.session_state['user_input_content'] = ''
@@ -36,9 +33,6 @@ with st.sidebar:
36
  key='current_chat' + st.session_state['history_chats'][st.session_state["current_chat_index"]],
37
  # on_change=current_chat_callback # 此处不适合用回调,无法识别到窗口增减的变动
38
  )
39
- if st.session_state['pre_chat'] != current_chat:
40
- st.session_state['pre_chat'] = current_chat
41
- st.session_state['if_chat_change'] = True
42
  st.write("---")
43
 
44
  c1, c2 = st.columns(2)
@@ -54,7 +48,6 @@ with st.sidebar:
54
  if len(st.session_state['history_chats']) == 1:
55
  chat_init = 'New Chat_' + str(uuid.uuid4())
56
  st.session_state['history_chats'].append(chat_init)
57
- st.session_state['current_chat'] = chat_init
58
  pre_chat_index = st.session_state['history_chats'].index(current_chat)
59
  if pre_chat_index > 0:
60
  st.session_state["current_chat_index"] = st.session_state['history_chats'].index(current_chat) - 1
@@ -73,28 +66,27 @@ with st.sidebar:
73
  st.markdown('<a href="https://github.com/PierXuY/ChatGPT-Assistant" target="_blank" rel="ChatGPT-Assistant">'
74
  '<img src="https://badgen.net/badge/icon/GitHub?icon=github&amp;label=ChatGPT Assistant" alt="GitHub">'
75
  '</a>', unsafe_allow_html=True)
 
76
  # 加载数据
77
- if ("history" + current_chat not in st.session_state) or (st.session_state['if_chat_change']):
78
  for key, value in load_data(st.session_state["path"], current_chat).items():
79
  if key == 'history':
80
  st.session_state[key + current_chat] = value
81
  else:
82
  for k, v in value.items():
83
- st.session_state[k + current_chat + 'default'] = v
84
- st.session_state['if_chat_change'] = False
85
-
86
 
87
  # 一键复制按钮
88
  st.markdown('<center><a href="https://huggingface.co/spaces/Pearx/ChatGPT-Assistant?duplicate=true">'
89
  '<img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><center>', unsafe_allow_html=True)
 
90
  # 对话展示
91
  show_messages(st.session_state["history" + current_chat])
92
 
93
 
94
  # 数据写入文件
95
  def write_data(new_chat_name=current_chat):
96
- # 防止高频创建时组件尚未渲染完成,不影响正常写入
97
- if "frequency_penalty" + current_chat in st.session_state:
98
  st.session_state["paras"] = {
99
  "temperature": st.session_state["temperature" + current_chat],
100
  "top_p": st.session_state["top_p" + current_chat],
@@ -110,6 +102,13 @@ def write_data(new_chat_name=current_chat):
110
  st.session_state["paras"], st.session_state["contexts"])
111
 
112
 
 
 
 
 
 
 
 
113
  # 输入内容展示
114
  area_user_svg = st.empty()
115
  area_user_content = st.empty()
@@ -119,19 +118,27 @@ area_gpt_content = st.empty()
119
  # 报错展示
120
  area_error = st.empty()
121
 
 
 
122
  st.header('ChatGPT Assistant')
123
  tap_input, tap_context, tap_set = st.tabs(['💬 聊天', '🗒️ 预设', '⚙️ 设置'])
124
 
125
  with tap_context:
126
  set_context_list = list(set_context_all.keys())
127
- context_select_index = set_context_list.index(st.session_state['context_select' + current_chat + "default"])
128
- st.selectbox(label='选择上下文', options=set_context_list, key='context_select' + current_chat,
129
- index=context_select_index, on_change=write_data)
 
 
 
 
 
130
  st.caption(set_context_all[st.session_state['context_select' + current_chat]])
131
- context_input = st.text_area(label='补充或自定义上下文:', key="context_input" + current_chat,
132
- value=st.session_state['context_input' + current_chat + "default"],
133
- on_change=write_data)
134
- st.caption(context_input)
 
135
 
136
  with tap_set:
137
  def clear_button_callback():
@@ -139,7 +146,17 @@ with tap_set:
139
  write_data()
140
 
141
 
142
- st.button("清空聊天记录", use_container_width=True, on_click=clear_button_callback)
 
 
 
 
 
 
 
 
 
 
143
 
144
  st.markdown("OpenAI API Key (可选)")
145
  st.text_input("OpenAI API Key (可选)", type='password', key='apikey_input', label_visibility='collapsed')
@@ -147,27 +164,30 @@ with tap_set:
147
  "此Key仅在当前网页有效,且优先级高于Secrets中的配置,仅自己可用,他人无法共享。[官网获取](https://platform.openai.com/account/api-keys)")
148
 
149
  st.markdown("包含对话次数:")
150
- st.slider("Context Level", 0, 10, st.session_state['context_level' + current_chat + "default"], 1,
151
- on_change=write_data,
152
- key='context_level' + current_chat, help="表示每次会话中包含的历史对话次数,预设内容不计算在内。")
 
 
 
153
 
154
  st.markdown("模型参数:")
155
- st.slider("Temperature", 0.0, 2.0, st.session_state["temperature" + current_chat + "default"], 0.1,
156
  help="""在0和2之间,应该使用什么样的采样温度?较高的值(如0.8)会使输出更随机,而较低的值(如0.2)则会使其更加集中和确定性。
157
- 我们一般建议只更改这个参数或top_p参数中的一个,而不要同时更改两个。""",
158
- on_change=write_data, key='temperature' + current_chat)
159
- st.slider("Top P", 0.1, 1.0, st.session_state["top_p" + current_chat + "default"], 0.1,
160
  help="""一种替代采用温度进行采样的方法,称为“基于核心概率”的采样。在该方法中,模型会考虑概率最高的top_p个标记的预测结果。
161
- 因此,当该参数为0.1时,只有包括前10%概率质量的标记将被考虑。我们一般建议只更改这个参数或采样温度参数中的一个,而不要同时更改两个。""",
162
- on_change=write_data, key='top_p' + current_chat)
163
  st.slider("Presence Penalty", -2.0, 2.0,
164
- st.session_state["presence_penalty" + current_chat + "default"], 0.1,
165
  help="""该参数的取值范围为-2.0到2.0。正值会根据新标记是否出现在当前生成的文本中对其进行惩罚,从而增加模型谈论新话题的可能性。""",
166
- on_change=write_data, key='presence_penalty' + current_chat)
167
  st.slider("Frequency Penalty", -2.0, 2.0,
168
- st.session_state["frequency_penalty" + current_chat + "default"], 0.1,
169
  help="""该参数的取值范围为-2.0到2.0。正值会根据新标记在当前生成的文本中的已有频率对其进行惩罚,从而减少模型直接重复相同语句的可能性。""",
170
- on_change=write_data, key='frequency_penalty' + current_chat)
171
  st.caption("[官网参数说明](https://platform.openai.com/docs/api-reference/completions/create)")
172
 
173
  with tap_input:
@@ -177,14 +197,18 @@ with tap_input:
177
  user_input_content = st.session_state['user_input_area']
178
  df_history = pd.DataFrame(st.session_state["history" + current_chat])
179
  if len(df_history.query('role!="system"')) == 0:
180
- remove_data(st.session_state["path"], current_chat)
181
  current_chat_index = st.session_state['history_chats'].index(current_chat)
182
  new_name = extract_chars(user_input_content, 18) + '_' + str(uuid.uuid4())
 
183
  st.session_state['history_chats'][current_chat_index] = new_name
184
  st.session_state["current_chat_index"] = current_chat_index
185
  # 写入新文件
186
  write_data(new_name)
187
-
 
 
 
 
188
 
189
 
190
  with st.form("input_form", clear_on_submit=True):
@@ -197,19 +221,22 @@ with tap_input:
197
  if 'r' in st.session_state:
198
  st.session_state.pop("r")
199
  st.session_state[current_chat + 'report'] = ""
200
- st.session_state['pre_user_input_content'] = (remove_hashtag_right__space(st.session_state['user_input_content']
201
- .replace('\n', '\n\n')))
 
202
  st.session_state['user_input_content'] = ''
 
203
  show_each_message(st.session_state['pre_user_input_content'], 'user',
204
  [area_user_svg.markdown, area_user_content.markdown])
 
205
  context_level_tem = st.session_state['context_level' + current_chat]
206
- history_tem = get_history_input(st.session_state["history" + current_chat], context_level_tem) + \
207
- [{"role": "user", "content": st.session_state['pre_user_input_content']}]
208
- history_need_input = ([{"role": "system",
209
- "content": set_context_all[st.session_state['context_select' + current_chat]]}]
210
- + [{"role": "system",
211
- "content": st.session_state['context_input' + current_chat]}]
212
- + history_tem)
213
  paras_need_input = {
214
  "temperature": st.session_state["temperature" + current_chat],
215
  "top_p": st.session_state["top_p" + current_chat],
@@ -222,7 +249,7 @@ with tap_input:
222
  openai.api_key = apikey
223
  else:
224
  openai.api_key = st.secrets["apikey"]
225
- r = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=history_need_input, stream=True,
226
  **paras_need_input)
227
  except (FileNotFoundError, KeyError):
228
  area_error.error("缺失 OpenAI API Key,请在复制项目后配置Secrets,或者在设置中进行临时配置。"
@@ -261,7 +288,7 @@ if ("r" in st.session_state) and (current_chat == st.session_state["chat_of_r"])
261
  st.session_state["history" + current_chat].append(
262
  {"role": "user", "content": st.session_state['pre_user_input_content']})
263
  st.session_state["history" + current_chat].append(
264
- {"role": "assistant", "content": st.session_state[current_chat + 'report']})
265
  write_data()
266
 
267
  # 用户在网页点击stop时,ss某些情形下会暂时为空
@@ -269,6 +296,7 @@ if ("r" in st.session_state) and (current_chat == st.session_state["chat_of_r"])
269
  st.session_state.pop(current_chat + 'report')
270
  if 'r' in st.session_state:
271
  st.session_state.pop("r")
 
272
 
273
  # 添加事件监听
274
  v1.html(js_code, height=0)
 
6
  import openai
7
  from requests.models import ChunkedEncodingError
8
  from streamlit.components import v1
 
9
 
10
  st.set_page_config(page_title='ChatGPT Assistant', layout='wide', page_icon='🤖')
11
  # 自定义元素样式
 
13
 
14
  if "initial_settings" not in st.session_state:
15
  # 历史聊天窗口
16
+ st.session_state["path"] = 'history_chats_file'
17
  st.session_state['history_chats'] = get_history_chats(st.session_state["path"])
18
  # ss参数初始化
 
 
19
  st.session_state['error_info'] = ''
20
  st.session_state["current_chat_index"] = 0
21
  st.session_state['user_input_content'] = ''
 
33
  key='current_chat' + st.session_state['history_chats'][st.session_state["current_chat_index"]],
34
  # on_change=current_chat_callback # 此处不适合用回调,无法识别到窗口增减的变动
35
  )
 
 
 
36
  st.write("---")
37
 
38
  c1, c2 = st.columns(2)
 
48
  if len(st.session_state['history_chats']) == 1:
49
  chat_init = 'New Chat_' + str(uuid.uuid4())
50
  st.session_state['history_chats'].append(chat_init)
 
51
  pre_chat_index = st.session_state['history_chats'].index(current_chat)
52
  if pre_chat_index > 0:
53
  st.session_state["current_chat_index"] = st.session_state['history_chats'].index(current_chat) - 1
 
66
  st.markdown('<a href="https://github.com/PierXuY/ChatGPT-Assistant" target="_blank" rel="ChatGPT-Assistant">'
67
  '<img src="https://badgen.net/badge/icon/GitHub?icon=github&amp;label=ChatGPT Assistant" alt="GitHub">'
68
  '</a>', unsafe_allow_html=True)
69
+
70
  # 加载数据
71
+ if "history" + current_chat not in st.session_state:
72
  for key, value in load_data(st.session_state["path"], current_chat).items():
73
  if key == 'history':
74
  st.session_state[key + current_chat] = value
75
  else:
76
  for k, v in value.items():
77
+ st.session_state[k + current_chat + "value"] = v
 
 
78
 
79
  # 一键复制按钮
80
  st.markdown('<center><a href="https://huggingface.co/spaces/Pearx/ChatGPT-Assistant?duplicate=true">'
81
  '<img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><center>', unsafe_allow_html=True)
82
+
83
  # 对话展示
84
  show_messages(st.session_state["history" + current_chat])
85
 
86
 
87
  # 数据写入文件
88
  def write_data(new_chat_name=current_chat):
89
+ if "apikey" in st.secrets:
 
90
  st.session_state["paras"] = {
91
  "temperature": st.session_state["temperature" + current_chat],
92
  "top_p": st.session_state["top_p" + current_chat],
 
102
  st.session_state["paras"], st.session_state["contexts"])
103
 
104
 
105
+ def callback_fun(arg):
106
+ # 连续快速点击新建与删除会触发错误回调,增加判断
107
+ if ("history" + current_chat in st.session_state) and ("frequency_penalty" + current_chat in st.session_state):
108
+ write_data()
109
+ st.session_state[arg + current_chat + "value"] = st.session_state[arg + current_chat]
110
+
111
+
112
  # 输入内容展示
113
  area_user_svg = st.empty()
114
  area_user_content = st.empty()
 
118
  # 报错展示
119
  area_error = st.empty()
120
 
121
+ st.write("\n")
122
+ st.write("\n")
123
  st.header('ChatGPT Assistant')
124
  tap_input, tap_context, tap_set = st.tabs(['💬 聊天', '🗒️ 预设', '⚙️ 设置'])
125
 
126
  with tap_context:
127
  set_context_list = list(set_context_all.keys())
128
+ context_select_index = set_context_list.index(st.session_state['context_select' + current_chat + "value"])
129
+ st.selectbox(
130
+ label='选择上下文',
131
+ options=set_context_list,
132
+ key='context_select' + current_chat,
133
+ index=context_select_index,
134
+ on_change=callback_fun,
135
+ args=("context_select",))
136
  st.caption(set_context_all[st.session_state['context_select' + current_chat]])
137
+
138
+ st.text_area(
139
+ label='补充或自定义上下文:', key="context_input" + current_chat,
140
+ value=st.session_state['context_input' + current_chat + "value"],
141
+ on_change=callback_fun, args=("context_input",))
142
 
143
  with tap_set:
144
  def clear_button_callback():
 
146
  write_data()
147
 
148
 
149
+ c1, c2 = st.columns(2)
150
+ with c1:
151
+ st.button("清空聊天记录", use_container_width=True, on_click=clear_button_callback)
152
+ with c2:
153
+ btn = st.download_button(
154
+ label="导出聊天记录",
155
+ data=download_history(st.session_state['history' + current_chat]),
156
+ file_name=f'{current_chat.split("_")[0]}.md',
157
+ mime="text/markdown",
158
+ use_container_width=True
159
+ )
160
 
161
  st.markdown("OpenAI API Key (可选)")
162
  st.text_input("OpenAI API Key (可选)", type='password', key='apikey_input', label_visibility='collapsed')
 
164
  "此Key仅在当前网页有效,且优先级高于Secrets中的配置,仅自己可用,他人无法共享。[官网获取](https://platform.openai.com/account/api-keys)")
165
 
166
  st.markdown("包含对话次数:")
167
+ st.slider(
168
+ "Context Level", 0, 10,
169
+ st.session_state['context_level' + current_chat + "value"], 1,
170
+ on_change=callback_fun,
171
+ key='context_level' + current_chat, args=('context_level',),
172
+ help="表示每次会话中包含的历史对话次数,预设内容不计算在内。")
173
 
174
  st.markdown("模型参数:")
175
+ st.slider("Temperature", 0.0, 2.0, st.session_state["temperature" + current_chat + "value"], 0.1,
176
  help="""在0和2之间,应该使用什么样的采样温度?较高的值(如0.8)会使输出更随机,而较低的值(如0.2)则会使其更加集中和确定性。
177
+ 我们一般建议只更改这个参数或top_p参数中的一个,而不要同时更改两个。""",
178
+ on_change=callback_fun, key='temperature' + current_chat, args=('temperature',))
179
+ st.slider("Top P", 0.1, 1.0, st.session_state["top_p" + current_chat + "value"], 0.1,
180
  help="""一种替代采用温度进行采样的方法,称为“基于核心概率”的采样。在该方法中,模型会考虑概率最高的top_p个标记的预测结果。
181
+ 因此,当该参数为0.1时,只有包括前10%概率质量的标记将被考虑。我们一般建议只更改这个参数或采样温度参数中的一个,而不要同时更改两个。""",
182
+ on_change=callback_fun, key='top_p' + current_chat, args=('top_p',))
183
  st.slider("Presence Penalty", -2.0, 2.0,
184
+ st.session_state["presence_penalty" + current_chat + "value"], 0.1,
185
  help="""该参数的取值范围为-2.0到2.0。正值会根据新标记是否出现在当前生成的文本中对其进行惩罚,从而增加模型谈论新话题的可能性。""",
186
+ on_change=callback_fun, key='presence_penalty' + current_chat, args=('presence_penalty',))
187
  st.slider("Frequency Penalty", -2.0, 2.0,
188
+ st.session_state["frequency_penalty" + current_chat + "value"], 0.1,
189
  help="""该参数的取值范围为-2.0到2.0。正值会根据新标记在当前生成的文本中的已有频率对其进行惩罚,从而减少模型直接重复相同语句的可能性。""",
190
+ on_change=callback_fun, key='frequency_penalty' + current_chat, args=('frequency_penalty',))
191
  st.caption("[官网参数说明](https://platform.openai.com/docs/api-reference/completions/create)")
192
 
193
  with tap_input:
 
197
  user_input_content = st.session_state['user_input_area']
198
  df_history = pd.DataFrame(st.session_state["history" + current_chat])
199
  if len(df_history.query('role!="system"')) == 0:
 
200
  current_chat_index = st.session_state['history_chats'].index(current_chat)
201
  new_name = extract_chars(user_input_content, 18) + '_' + str(uuid.uuid4())
202
+ new_name = filename_correction(new_name)
203
  st.session_state['history_chats'][current_chat_index] = new_name
204
  st.session_state["current_chat_index"] = current_chat_index
205
  # 写入新文件
206
  write_data(new_name)
207
+ # 转移数据
208
+ st.session_state['history' + new_name] = st.session_state['history' + current_chat]
209
+ for item in ["context_select", "context_input", "context_level", *initial_content_all['paras']]:
210
+ st.session_state[item + new_name + "value"] = st.session_state[item + current_chat + "value"]
211
+ remove_data(st.session_state["path"], current_chat)
212
 
213
 
214
  with st.form("input_form", clear_on_submit=True):
 
221
  if 'r' in st.session_state:
222
  st.session_state.pop("r")
223
  st.session_state[current_chat + 'report'] = ""
224
+ st.session_state['pre_user_input_content'] = url_correction(
225
+ remove_hashtag_right__space(st.session_state['user_input_content']
226
+ .replace('\n', '\n\n')))
227
  st.session_state['user_input_content'] = ''
228
+
229
  show_each_message(st.session_state['pre_user_input_content'], 'user',
230
  [area_user_svg.markdown, area_user_content.markdown])
231
+
232
  context_level_tem = st.session_state['context_level' + current_chat]
233
+ history_need_input = (get_history_input(st.session_state["history" + current_chat], context_level_tem) +
234
+ [{"role": "user", "content": st.session_state['pre_user_input_content']}])
235
+ for ctx in [st.session_state['context_input' + current_chat],
236
+ set_context_all[st.session_state['context_select' + current_chat]]]:
237
+ if ctx != "":
238
+ history_need_input = [{"role": "system", "content": ctx}] + history_need_input
239
+
240
  paras_need_input = {
241
  "temperature": st.session_state["temperature" + current_chat],
242
  "top_p": st.session_state["top_p" + current_chat],
 
249
  openai.api_key = apikey
250
  else:
251
  openai.api_key = st.secrets["apikey"]
252
+ r = openai.ChatCompletion.create(model=model, messages=history_need_input, stream=True,
253
  **paras_need_input)
254
  except (FileNotFoundError, KeyError):
255
  area_error.error("缺失 OpenAI API Key,请在复制项目后配置Secrets,或者在设置中进行临时配置。"
 
288
  st.session_state["history" + current_chat].append(
289
  {"role": "user", "content": st.session_state['pre_user_input_content']})
290
  st.session_state["history" + current_chat].append(
291
+ {"role": "assistant", "content": url_correction(st.session_state[current_chat + 'report'])})
292
  write_data()
293
 
294
  # 用户在网页点击stop时,ss某些情形下会暂时为空
 
296
  st.session_state.pop(current_chat + 'report')
297
  if 'r' in st.session_state:
298
  st.session_state.pop("r")
299
+ st.experimental_rerun()
300
 
301
  # 添加事件监听
302
  v1.html(js_code, height=0)