Charles Chan
committed on
Commit
·
1e21aa9
1
Parent(s):
a054c10
coding
Browse files
app.py
CHANGED
@@ -7,34 +7,55 @@ from datasets import load_dataset
|
|
7 |
from opencc import OpenCC
|
8 |
|
9 |
# 使用 進擊的巨人 数据集
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
# 构建向量数据库 (如果需要,仅构建一次)
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
st.
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# 问答函数
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def answer_question(repo_id, temperature, max_length, question):
|
30 |
# 初始化 Gemma 模型
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
38 |
|
39 |
# 获取答案
|
40 |
try:
|
|
|
7 |
from opencc import OpenCC

# Load the "Attack on Titan" wiki Q&A dataset (once per session).
# The source dataset is Traditional Chinese; answers are converted to
# Simplified Chinese to make debugging easier.
if "dataset_loaded" not in st.session_state:
    st.session_state.dataset_loaded = False
if not st.session_state.dataset_loaded:
    try:
        with st.spinner("正在读取数据库..."):
            # 'tw2s' is the OpenCC profile for Traditional (Taiwan) -> Simplified.
            converter = OpenCC('tw2s')
            dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
            # BUG FIX: persist the converted answers in session_state.
            # Streamlit reruns this script on every interaction; a plain
            # local would be lost once dataset_loaded is True and this
            # branch is skipped, causing a NameError downstream.
            st.session_state.answer_list = [
                converter.convert(example["Answer"])
                for example in dataset["train"]
            ]
            st.success("数据库读取完成!")
    except Exception as e:
        st.error(f"读取数据集失败:{e}")
        st.stop()
    st.session_state.dataset_loaded = True
# Re-expose under the original module-level name so downstream code,
# written against `answer_list`, keeps working on every rerun.
answer_list = st.session_state.answer_list
|
24 |
|
25 |
# Build the FAISS vector database from the answer texts (once per session).
if "vector_created" not in st.session_state:
    st.session_state.vector_created = False
if not st.session_state.vector_created:
    try:
        with st.spinner("正在构建向量数据库..."):
            embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
            # BUG FIX: persist the index in session_state. Streamlit reruns
            # the script on every interaction; a plain local `db` would be
            # lost once vector_created is True and this branch is skipped.
            st.session_state.db = FAISS.from_texts(answer_list, embeddings)
            st.success("向量数据库构建完成!")
    except Exception as e:
        st.error(f"向量数据库构建失败:{e}")
        st.stop()
    st.session_state.vector_created = True
# Re-expose under the original module-level name so downstream code
# written against `db` keeps working on every rerun.
db = st.session_state.db
|
38 |
|
39 |
# 问答函数
|
40 |
+
# Seed the session-state defaults used to detect whether the LLM
# configuration changed between reruns. The empty string is a sentinel
# meaning "model not initialised yet", so the first real config always
# differs and triggers initialisation.
for _param in ("repo_id", "temperature", "max_length"):
    if _param not in st.session_state:
        st.session_state[_param] = ''
|
46 |
def answer_question(repo_id, temperature, max_length, question):
|
47 |
# 初始化 Gemma 模型
|
48 |
+
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
49 |
+
try:
|
50 |
+
with st.spinner("正在初始化 Gemma 模型..."):
|
51 |
+
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
|
52 |
+
st.success("Gemma 模型初始化完成!")
|
53 |
+
st.session_state.repo_id = repo_id
|
54 |
+
st.session_state.temperature = temperature
|
55 |
+
st.session_state.max_length = max_length
|
56 |
+
except Exception as e:
|
57 |
+
st.error(f"Gemma 模型加载失败:{e}")
|
58 |
+
st.stop()
|
59 |
|
60 |
# 获取答案
|
61 |
try:
|