ChatGPT-Assistant / helper.py
Pearx's picture
Update helper.py
5600923
raw
history blame
4.98 kB
import json
import os
import re
import uuid
import streamlit as st
import pandas as pd
from custom import *
import copy
import io
def get_history_chats(path: str) -> list:
    """Return the chat names to offer, newest first.

    When the Streamlit secrets contain an ``apikey`` entry, ``path`` is created
    if it does not exist yet and every ``*.json`` file inside it is listed
    (extension stripped), ordered by ``st_ctime`` in descending order.
    Without that entry the file system is not touched at all.
    NOTE(review): presumably ``apikey`` marks a deployment where chats are
    persisted to disk - confirm against the caller.

    :param path: directory holding the saved chats (relative or absolute).
    :return: list of chat names; a single freshly generated
        ``'New Chat_<uuid4>'`` entry when there is nothing to list.
    """
    if "apikey" not in st.secrets:
        return ['New Chat_' + str(uuid.uuid4())]

    os.makedirs(path, exist_ok=True)
    # Every path is derived from `path` itself (no './' prefix) so that the
    # directory that is created and the one that is listed are always the
    # same, including when `path` is absolute.
    stamped = [
        (entry, os.stat(os.path.join(path, entry)).st_ctime)
        for entry in os.listdir(path)
        if entry.endswith('.json')
    ]
    stamped.sort(key=lambda pair: pair[1], reverse=True)
    chat_names = [os.path.splitext(entry)[0] for entry, _ in stamped]
    if not chat_names:
        chat_names.append('New Chat_' + str(uuid.uuid4()))
    return chat_names
def save_data(path: str, file_name: str, history: list, paras: dict, contexts: dict, **kwargs):
    """Serialize one chat to ``<path>/<file_name>.json``.

    The file holds a single JSON object with the keys ``history``, ``paras``
    and ``contexts`` plus any extra keyword arguments.  An existing file is
    overwritten.

    :param path: target directory (relative or absolute); created when missing.
    :param file_name: file name without the ``.json`` extension.
    :param history: value stored under ``"history"``.
    :param paras: value stored under ``"paras"``.
    :param contexts: value stored under ``"contexts"``.
    :param kwargs: additional top-level entries for the JSON object.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path, exist_ok=True)
    # Join onto `path` directly: prefixing './' would send absolute paths to a
    # different location than the directory created above.
    target = os.path.join(path, f"{file_name}.json")
    with open(target, 'w', encoding='utf-8') as f:
        json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)
def remove_data(path: str, chat_name: str):
    """Delete a saved chat and drop its entries from the Streamlit session state.

    A missing file or a missing session-state key is not an error.

    :param path: directory holding the saved chats (relative or absolute).
    :param chat_name: chat name, i.e. the file name without ``.json``.
    """
    try:
        os.remove(os.path.join(path, f"{chat_name}.json"))
    except FileNotFoundError:
        pass  # nothing on disk for this chat

    # Clear the cached per-chat state.  The keys are '<item><chat_name>value'
    # for the three context widgets and for every entry yielded by iterating
    # initial_content_all['paras'], plus 'history<chat_name>'.
    stale_keys = ['history' + chat_name]
    stale_keys.extend(
        item + chat_name + "value"
        for item in ["context_select", "context_input", "context_level", *initial_content_all['paras']]
    )
    for key in stale_keys:
        # Guard each key separately: one absent key must not stop the
        # remaining ones from being cleared.
        try:
            st.session_state.pop(key)
        except KeyError:
            pass
def load_data(path: str, file_name: str) -> dict:
    """Load one chat from ``<path>/<file_name>.json``.

    If the file does not exist, a deep copy of ``initial_content_all`` is
    returned instead; when the Streamlit secrets contain an ``apikey`` entry
    that default content is also written to the file.

    :param path: directory holding the saved chats (relative or absolute).
    :param file_name: file name without the ``.json`` extension.
    :return: the decoded JSON content, or the default content.
    """
    # Join onto `path` directly; a './' prefix would break absolute paths.
    target = os.path.join(path, f"{file_name}.json")
    try:
        with open(target, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        pass

    # Deep copy so callers can mutate the result without touching the template.
    content = copy.deepcopy(initial_content_all)
    if "apikey" in st.secrets:
        if path:
            # The directory may not exist yet on a first run.
            os.makedirs(path, exist_ok=True)
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(content, f)
    return content
def show_each_message(message: str, role: str, area=None):
    """Render a single chat message as an avatar header followed by its content.

    :param message: message text; it is passed through ``url_correction`` and
        then ``colon_correction`` before being rendered.
    :param role: ``'user'`` selects the user icon, name and background colour;
        any other value selects the GPT ones.
    :param area: optional indexable of two callables.  Item 0 receives the
        header HTML and item 1 the content HTML, both with
        ``unsafe_allow_html=True``.  Defaults to ``st.markdown`` for both.
    """
    if area is None:
        area = [st.markdown, st.markdown]

    icon, name, background_color = (
        (user_svg, user_name, user_background_color)
        if role == 'user'
        else (gpt_svg, gpt_name, gpt_background_color)
    )

    message = colon_correction(url_correction(message))

    header_html = f"\n<div class='avatar'>{icon}<h2>{name}:</h2></div>"
    area[0](header_html, unsafe_allow_html=True)

    body_html = f"""<div class='content-div' style='background-color: {background_color};'>\n\n{message}"""
    area[1](body_html, unsafe_allow_html=True)
def show_messages(messages: list):
    """Render every user and assistant message of a conversation in order.

    Messages with any other role are skipped.  A ``---`` separator is written
    after each assistant message.

    :param messages: list of dicts with at least ``"role"`` and ``"content"``.
    """
    for entry in messages:
        role = entry["role"]
        if role == "user":
            show_each_message(entry["content"], role)
        elif role == "assistant":
            show_each_message(entry["content"], role)
            st.write("---")
# Extract the part of the history selected by context_level
def get_history_input(history: list, level: int) -> list:
    """Return the messages selected by the context level.

    With ``level == 0`` nothing is selected.  Otherwise the result consists of
    all ``system`` messages (in their original order) followed by the last
    ``level * 2`` non-system messages.

    Plain list operations are used instead of a DataFrame round-trip so that
    an empty history does not raise and messages with differing key sets are
    returned exactly as stored, without NaN-valued keys added.

    :param history: list of message dicts carrying a ``"role"`` key.
    :param level: number of most recent exchanges to keep; ``0`` means none.
    :return: new list of shallow copies of the selected messages.
    """
    if level == 0:
        return []
    system_msgs = [dict(msg) for msg in history if msg.get("role") == "system"]
    dialogue = [dict(msg) for msg in history if msg.get("role") != "system"]
    return system_msgs + dialogue[-level * 2:]
# Remove the whitespace to the right of '#' characters
# def remove_hashtag_right__space(text: str) -> str:
# text = re.sub(r"(#+)\s*", r"\1", text)
# return text
# Extract a prefix of the text
def extract_chars(text: str, num: int) -> str:
    """Return the shortest prefix of ``text`` whose display width reaches ``num``.

    Characters in the range U+4E00..U+9FFF count as width 2, all others as
    width 1.  A character is taken before the width is checked, so a non-empty
    ``text`` always yields at least one character and the final width may
    exceed ``num``.  The whole text is returned when it is narrower than ``num``.
    """
    width = 0
    taken = []
    for ch in text:
        taken.append(ch)
        # A CJK ideograph counts as two characters
        width += 2 if '\u4e00' <= ch <= '\u9fff' else 1
        if width >= num:
            break
    return ''.join(taken)
@st.cache_data(max_entries=20, show_spinner=False)
def download_history(history: list):
    """Build a Markdown export of a conversation as an in-memory byte stream.

    Each user or assistant message becomes a ``## <name>:`` heading followed
    by the message content; messages with other roles are left out.

    :param history: list of dicts with ``"role"`` and ``"content"`` keys.
    :return: ``io.BytesIO`` holding the UTF-8 encoded text, positioned at 0.
    """
    sections = []
    for entry in history:
        role = entry['role']
        if role == 'user':
            sections.append(f'## {user_name}:\n{entry["content"]}\n')
        elif role == 'assistant':
            sections.append(f'## {gpt_name}:\n{entry["content"]}\n')
    return io.BytesIO(''.join(sections).encode('utf-8'))
def filename_correction(filename: str) -> str:
    """Strip every character that is not a word character, ``.`` or ``-``.

    :param filename: proposed file name.
    :return: the name with all other characters removed.
    """
    disallowed = re.compile(r'[^\w\.-]')
    return disallowed.sub('', filename)
def url_correction(text: str) -> str:
    """Surround every URL-like token in ``text`` with one space on each side.

    A token starts with ``http://``, ``https://`` or ``www.`` and extends over
    the characters accepted by the pattern below.
    """
    url_token = re.compile(r'((?:http[s]?://|www\.)(?:[a-zA-Z0-9]|[$-_\~#!])+)')
    return url_token.sub(r' \g<1> ', text)
# Streamlit's markdown mis-renders a colon followed by a Latin letter, e.g. :abc
def colon_correction(text):
    """Escape colons so that ``:abc``-style sequences are shown literally.

    If ``text`` contains a colon directly followed by an ASCII letter, every
    ``:`` in the text is replaced by the HTML entity ``&#58;``.  Afterwards the
    entity is turned back into ``:`` inside spans delimited by single or
    triple backticks.  Text without such a colon is returned unchanged.
    """
    if not re.search(r':[a-zA-Z]', text):
        return text

    escaped = text.replace(":", "&#58;")

    def _restore(match):
        # Inside a backtick-delimited span the entity goes back to a real colon.
        return match.group(0).replace('&#58;', ':')

    code_span = r'`([^`]*)&#58;([^`]*)`|```([^`]*)&#58;([^`]*)```'
    return re.sub(code_span, _restore, escaped)