Charles Chan committed on
Commit
908d31d
·
1 Parent(s): cb8213b
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +65 -55
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: LocalQA
3
  emoji: 🐨
4
  colorFrom: gray
5
  colorTo: red
 
1
  ---
2
+ title: TitanQA
3
  emoji: 🐨
4
  colorFrom: gray
5
  colorTo: red
app.py CHANGED
@@ -5,15 +5,7 @@ from langchain_community.vectorstores import FAISS
5
  from datasets import load_dataset
6
  import random
7
 
8
- # 1. 准备知识库数据 (示例)
9
- knowledge_base = [
10
- "Gemma 是 Google 开发的大型语言模型。",
11
- "Gemma 具有强大的自然语言处理能力。",
12
- "Gemma 可以用于问答、对话、文本生成等任务。",
13
- "Gemma 基于 Transformer 架构。",
14
- "Gemma 支持多种语言。"
15
- ]
16
-
17
  try:
18
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
19
  answer_list = [example["Answer"] for example in dataset["train"]]
@@ -22,70 +14,88 @@ except Exception as e:
22
  st.error(f"读取数据集失败:{e}")
23
  st.stop()
24
 
25
-
26
- # 2. 构建向量数据库 (如果需要,仅构建一次)
27
  try:
28
- embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
29
- db = FAISS.from_texts(answer_list, embeddings)
 
 
30
  except Exception as e:
31
  st.error(f"向量数据库构建失败:{e}")
32
  st.stop()
33
 
34
- # 3. 问答函数
35
  def answer_question(repo_id, temperature, max_length, question):
36
- # 4. 初始化 Gemma 模型
37
  try:
38
- llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
 
39
  except Exception as e:
40
  st.error(f"Gemma 模型加载失败:{e}")
41
  st.stop()
42
 
43
- # 5. 获取答案
44
  try:
45
- question_embedding = embeddings.embed_query(question)
46
- question_embedding_str = " ".join(map(str, question_embedding))
47
- # print('question_embedding: ' + question_embedding_str)
48
- docs_and_scores = db.similarity_search_with_score(question_embedding_str)
 
49
 
50
- context = "\n".join([doc.page_content for doc, _ in docs_and_scores])
51
- print('context: ' + context)
52
 
53
- prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
54
- print('prompt: ' + prompt)
55
 
56
- answer = llm.invoke(prompt)
57
- return answer
 
 
 
58
  except Exception as e:
59
  st.error(f"问答过程出错:{e}")
60
- return "An error occurred during the answering process."
 
 
 
61
 
62
- # 6. Streamlit 界面
63
- st.title("Gemma 知识库问答系统")
 
 
 
 
64
 
65
- gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
66
- temperature = st.number_input("temperature", value=1.0)
67
- max_length = st.number_input("max_length", value=1024)
68
- question = st.text_area("请输入问题", "Gemma 有哪些特点?")
 
 
 
 
 
 
69
 
70
- if st.button("随机"):
71
- dataset_size = len(dataset["train"])
72
- random_index = random.randint(0, dataset_size - 1)
73
- # 读取随机问题
74
- random_question = dataset["train"][random_index]["Question"]
75
- origin_answer = dataset["train"][random_index]["Answer"]
76
- st.write("随机问题:")
77
- st.write(random_question)
78
- st.write("原始答案:")
79
- st.write(origin_answer)
80
- answer = answer_question(gemma, float(temperature), int(max_length), random_question)
81
- st.write("生成答案:")
82
- st.write(answer)
83
 
84
- if st.button("提交"):
85
- if not question:
86
- st.warning("请输入问题!")
87
- else:
88
- with st.spinner("正在查询..."):
89
- answer = answer_question(gemma, float(temperature), int(max_length), question)
90
- st.write("答案:")
91
- st.write(answer)
 
 
 
 
5
  from datasets import load_dataset
6
  import random
7
 
8
+ # 使用 進擊的巨人 数据集
 
 
 
 
 
 
 
 
9
  try:
10
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
11
  answer_list = [example["Answer"] for example in dataset["train"]]
 
14
  st.error(f"读取数据集失败:{e}")
15
  st.stop()
16
 
17
# Build the vector store once per server process.
# Streamlit re-executes this script from the top on every widget interaction,
# so without caching the embedding model would be reloaded and the FAISS index
# rebuilt on each button click.
@st.cache_resource(show_spinner=False)
def _build_vector_store(texts):
    """Return ``(embedder, index)`` for the given list of answer texts.

    ``embedder`` is the sentence-transformer embedding wrapper and ``index``
    is the FAISS vector store built from ``texts`` with that embedder.
    """
    embedder = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    return embedder, FAISS.from_texts(texts, embedder)


try:
    with st.spinner("正在读取数据库..."):
        # Keep the module-level names `embeddings` and `db`; the rest of the
        # script reads them as globals.
        embeddings, db = _build_vector_store(answer_list)
    st.success("数据库读取完成!")
except Exception as e:
    # Nothing below can work without the index, so report and halt the run.
    st.error(f"向量数据库构建失败:{e}")
    st.stop()
26
 
27
# Question-answering function
def answer_question(repo_id, temperature, max_length, question):
    """Answer ``question`` with retrieval-augmented generation.

    Args:
        repo_id: Hugging Face Hub model id passed to ``HuggingFaceHub``.
        temperature: Sampling temperature forwarded in ``model_kwargs``.
        max_length: Maximum generation length forwarded in ``model_kwargs``.
        question: The user's question text.

    Returns:
        A dict with two string entries: ``"prompt"`` (the prompt sent to the
        model, or ``""`` on failure) and ``"answer"`` (the model output with
        the echoed prompt removed, or a fixed error message on failure).

    If the model client cannot be created, an error is shown and the Streamlit
    run is halted via ``st.stop()``. Errors during retrieval or generation are
    shown with ``st.error`` and reported through the returned dict.
    """
    # Initialise the hosted model client.
    try:
        with st.spinner("正在初始化 Gemma 模型..."):
            llm = HuggingFaceHub(
                repo_id=repo_id,
                model_kwargs={"temperature": temperature, "max_length": max_length},
            )
    except Exception as e:
        st.error(f"Gemma 模型加载失败:{e}")
        st.stop()

    # Retrieve context and generate the answer.
    try:
        with st.spinner("正在筛选本地数据集..."):
            # Pass the raw question text: the vector store embeds the query
            # itself. Embedding first and passing the vector rendered as a
            # string of floats would make the store embed that digit string,
            # so the search would no longer relate to the question.
            docs_and_scores = db.similarity_search_with_score(question)

            context = "\n".join(doc.page_content for doc, _ in docs_and_scores)
            print('context: ' + context)

            prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
            print('prompt: ' + prompt)

        with st.spinner("正在生成答案..."):
            answer = llm.invoke(prompt)
            # The model output may echo the prompt; strip it from the answer.
            answer = answer.replace(prompt, "").strip()
        return {"prompt": prompt, "answer": answer}
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return {"prompt": "", "answer": "An error occurred during the answering process."}
59
+
60
# Streamlit user interface
def _show_result(result):
    """Log the prompt/answer pair to stdout and render the answer on the page."""
    print('prompt: ' + result["prompt"])
    print('answer: ' + result["answer"])
    st.write("生成答案:")
    st.write(result["answer"])


st.title("進擊的巨人 知识库问答系统")

# Model selection and generation parameters, side by side.
col1, col2 = st.columns(2)
with col1:
    gemma = st.selectbox(
        "repo-id",
        ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"),
        2,
    )
with col2:
    temperature = st.number_input("temperature", value=1.0)
    max_length = st.number_input("max_length", value=1024)

col3, col4 = st.columns(2)
with col3:
    # "Random": pick a question from the dataset and compare the reference
    # answer with the generated one.
    if st.button("随机"):
        dataset_size = len(dataset["train"])
        random_index = random.randint(0, dataset_size - 1)
        # Fetch the row once and read both fields from it.
        row = dataset["train"][random_index]
        random_question = row["Question"]
        origin_answer = row["Answer"]
        # Log as "[index/size]"; the opening bracket was previously written as "[]".
        print('[' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
        print('origin_answer: ' + origin_answer)

        st.write("随机问题:")
        st.write(random_question)
        st.write("原始答案:")
        st.write(origin_answer)
        _show_result(answer_question(gemma, float(temperature), int(max_length), random_question))

with col4:
    # "Submit": answer the question typed by the user.
    question = st.text_area("请输入问题", "Gemma 有哪些特点?")
    if st.button("提交"):
        # Treat whitespace-only input as empty as well.
        if not question or not question.strip():
            st.warning("请输入问题!")
        else:
            _show_result(answer_question(gemma, float(temperature), int(max_length), question))