Charles Chan committed on
Commit
1e21aa9
·
1 Parent(s): a054c10
Files changed (1) hide show
  1. app.py +43 -22
app.py CHANGED
@@ -7,34 +7,55 @@ from datasets import load_dataset
7
  from opencc import OpenCC
8
 
9
  # 使用 進擊的巨人 数据集
10
- try:
11
- converter = OpenCC('tw2s.json') # 'tw2s.json' 表示繁体中文到简体中文的转换
12
- dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
13
- answer_list = [converter.convert(example["Answer"]) for example in dataset["train"]]
14
- except Exception as e:
15
- st.error(f"读取数据集失败:{e}")
16
- st.stop()
 
 
 
 
 
 
 
17
 
18
  # 构建向量数据库 (如果需要,仅构建一次)
19
- try:
20
- with st.spinner("正在读取数据库..."):
21
- embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
22
- db = FAISS.from_texts(answer_list, embeddings)
23
- st.success("数据库读取完成!")
24
- except Exception as e:
25
- st.error(f"向量数据库构建失败:{e}")
26
- st.stop()
 
 
 
 
27
 
28
  # 问答函数
 
 
 
 
 
 
29
  def answer_question(repo_id, temperature, max_length, question):
30
  # 初始化 Gemma 模型
31
- try:
32
- with st.spinner("正在初始化 Gemma 模型..."):
33
- llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
34
- st.success("Gemma 模型初始化完成!")
35
- except Exception as e:
36
- st.error(f"Gemma 模型加载失败:{e}")
37
- st.stop()
 
 
 
 
38
 
39
  # 获取答案
40
  try:
 
7
  from opencc import OpenCC
8
 
9
  # 使用 進擊的巨人 数据集
10
# The source dataset is written in Traditional Chinese; it is converted to
# Simplified Chinese before use to make debugging easier.
#
# Streamlit re-executes this script from the top on every interaction, so plain
# module-level variables do not survive a rerun. The converted answers are kept
# in st.session_state and re-bound to `answer_list` on every run; guarding on
# the boolean flag alone would skip the load on reruns and leave `answer_list`
# undefined.
if "dataset_loaded" not in st.session_state:
    st.session_state.dataset_loaded = False
if not st.session_state.dataset_loaded or "answer_list" not in st.session_state:
    try:
        with st.spinner("正在读取数据库..."):
            # 'tw2s' selects the Traditional-to-Simplified Chinese conversion.
            converter = OpenCC('tw2s')
            dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
            st.session_state.answer_list = [
                converter.convert(example["Answer"]) for example in dataset["train"]
            ]
        st.success("数据库读取完成!")
    except Exception as e:
        # Report the failure in the UI and halt this script run.
        st.error(f"读取数据集失败:{e}")
        st.stop()
    st.session_state.dataset_loaded = True
# Always expose the cached list under the module-level name used further down.
answer_list = st.session_state.answer_list
24
 
25
  # 构建向量数据库 (如果需要,仅构建一次)
26
# Build the FAISS vector index once per session.
#
# Streamlit re-runs the whole script on every interaction, so the index is
# stored in st.session_state and re-bound to `db` on each run. With only the
# boolean flag as a guard, `db` would be undefined on every rerun after the
# first, because the build is skipped and nothing assigns it.
if "vector_created" not in st.session_state:
    st.session_state.vector_created = False
if not st.session_state.vector_created or "vector_db" not in st.session_state:
    try:
        with st.spinner("正在构建向量数据库..."):
            embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
            st.session_state.vector_db = FAISS.from_texts(answer_list, embeddings)
        st.success("向量数据库构建完成!")
    except Exception as e:
        # Report the failure in the UI and halt this script run.
        st.error(f"向量数据库构建失败:{e}")
        st.stop()
    st.session_state.vector_created = True
# Always expose the cached index under the module-level name used further down.
db = st.session_state.vector_db
38
 
39
  # 问答函数
40
# Seed the session with empty placeholders for the model parameters, leaving
# any value that is already stored untouched.
for _param in ("repo_id", "temperature", "max_length"):
    if _param not in st.session_state:
        st.session_state[_param] = ''
46
  def answer_question(repo_id, temperature, max_length, question):
47
  # 初始化 Gemma 模型
48
+ if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
49
+ try:
50
+ with st.spinner("正在初始化 Gemma 模型..."):
51
+ llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
52
+ st.success("Gemma 模型初始化完成!")
53
+ st.session_state.repo_id = repo_id
54
+ st.session_state.temperature = temperature
55
+ st.session_state.max_length = max_length
56
+ except Exception as e:
57
+ st.error(f"Gemma 模型加载失败:{e}")
58
+ st.stop()
59
 
60
  # 获取答案
61
  try: